Merge pull request #82 from quantopian/new_world_order

New world order
This commit is contained in:
Scott Sanderson
2012-08-08 19:28:34 -07:00
40 changed files with 3263 additions and 1671 deletions
+381
View File
@@ -0,0 +1,381 @@
import zmq
import pytz
from pprint import pformat as pf
from datetime import datetime, timedelta
from unittest2 import TestCase
from collections import defaultdict
from zipline.gens.composites import date_sorted_sources, merged_transforms
from zipline.core.devsimulator import AddressAllocator
from zipline.gens.transform import Passthrough, StatefulTransform
from zipline.gens.mavg import MovingAverage
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
from zipline.utils.factory import create_trading_environment
from zipline.test_algorithms import TestAlgorithm
from zipline.utils.test_utils import (
setup_logger,
teardown_logger,
create_monitor,
launch_monitor
)
from zipline.core import Component
from zipline.protocol import (
DATASOURCE_FRAME,
DATASOURCE_UNFRAME,
FEED_FRAME,
FEED_UNFRAME,
MERGE_FRAME,
MERGE_UNFRAME,
SIMULATION_STYLE,
PERF_FRAME,
BT_UPDATE_UNFRAME
)
from zipline.gens.tradegens import SpecificEquityTrades
import logbook
log = logbook.Logger('ComponentTestCase')
allocator = AddressAllocator(1000)
class ComponentTestCase(TestCase):
    """
    Tests that drive trade sources, the date sort, the transform merge and
    the trade simulation client, either wrapped in ``Component`` objects or
    chained directly inside the test process.
    """

    leased_sockets = defaultdict(list)

    def setUp(self):
        """
        Build the shared test configuration, a zmq context, the logger, two
        trade sources and the 2002 trading environment.
        """
        self.zipline_test_config = {
            'allocator' : allocator,
            'sid' : 133,
            'devel' : False,
            'results_socket' : allocator.lease(1)[0],
            'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE
        }
        self.ctx = zmq.Context()
        setup_logger(self)

        count = 250
        sid_filter = [2, 3]

        # Source a: 2 * count events for sids 1-3, six hours between events.
        self.source_a = SpecificEquityTrades(
            count=2 * count,
            sids=[1, 2, 3],
            start=datetime(2002, 1, 3, 15, tzinfo=pytz.utc),
            delta=timedelta(hours=6),
            filter=sid_filter
        )

        # Source b: count events for sids 2-4, five minutes between events.
        self.source_b = SpecificEquityTrades(
            count=count,
            sids=[2, 3, 4],
            start=datetime(2002, 1, 3, 14, tzinfo=pytz.utc),
            delta=timedelta(minutes=5),
            filter=sid_filter
        )

        self.environment = create_trading_environment(year=2002)

    def tearDown(self):
        """Terminate the zmq context created in setUp and tear down logging."""
        # setUp creates a new context for every test; without terminating it
        # here each test would leave a context behind.
        self.ctx.term()
        teardown_logger(self)

    def test_source(self):
        """Drain a single trade source that is wrapped in a Component."""
        monitor = create_monitor(allocator)
        socket_uri = allocator.lease(1)[0]

        # One minute between events, count of 100.
        trade_gen = SpecificEquityTrades(
            sids=[1, 2],
            start=datetime(2012, 6, 6, 0, tzinfo=pytz.utc),
            delta=timedelta(minutes=1),
            filter=[1, 2, 3, 4],
            count=100
        )

        comp_a = Component(
            trade_gen,
            monitor,
            socket_uri,
            DATASOURCE_FRAME,
            DATASOURCE_UNFRAME,
            "source_a"
        )
        mon_proc = launch_monitor(monitor)

        for event in comp_a:
            log.info(event)

        # wait for the sending process to exit
        comp_a.proc.join()
        mon_proc.join()

    def test_sort(self):
        """
        Date-sort three Component-wrapped sources and check that the merged
        stream is in ascending dt order and contains every event.
        """
        monitor = create_monitor(allocator)
        socket_uris = allocator.lease(3)
        count = 100
        sid_filter = [1, 2, 3, 4]

        # All three sources emit one event per minute; they differ in their
        # sids and start dates.
        trade_gen_a = SpecificEquityTrades(
            sids=[1, 2],
            start=datetime(2012, 6, 6, 0, tzinfo=pytz.utc),
            delta=timedelta(minutes=1),
            filter=sid_filter,
            count=count
        )
        trade_gen_b = SpecificEquityTrades(
            sids=[2],
            start=datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
            delta=timedelta(minutes=1),
            filter=sid_filter,
            count=count
        )
        trade_gen_c = SpecificEquityTrades(
            sids=[3],
            start=datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
            delta=timedelta(minutes=1),
            filter=sid_filter,
            count=count
        )

        comp_a = Component(
            trade_gen_a,
            monitor,
            socket_uris[0],
            DATASOURCE_FRAME,
            DATASOURCE_UNFRAME,
            trade_gen_a.get_hash()
        )
        comp_b = Component(
            trade_gen_b,
            monitor,
            socket_uris[1],
            DATASOURCE_FRAME,
            DATASOURCE_UNFRAME,
            trade_gen_b.get_hash()
        )
        comp_c = Component(
            trade_gen_c,
            monitor,
            socket_uris[2],
            DATASOURCE_FRAME,
            DATASOURCE_UNFRAME,
            trade_gen_c.get_hash()
        )

        sorted_out = date_sorted_sources(comp_a, comp_b, comp_c)
        mon_proc = launch_monitor(monitor)

        prev = None
        sort_count = 0
        for msg in sorted_out:
            if prev:
                self.assertTrue(
                    msg.dt >= prev.dt,
                    "Messages should be in date ascending order"
                )
            prev = msg
            sort_count += 1

        self.assertEqual(count*3, sort_count)

        # wait for processes to finish
        comp_a.proc.join()
        comp_b.proc.join()
        comp_c.proc.join()
        mon_proc.join()

    def test_full(self):
        """
        Run the whole pipeline with every stage (sources, sort, merge and
        trade simulation) wrapped in its own Component.
        """
        monitor = create_monitor(allocator)

        # ------------------------
        # Run sources in dedicated processes
        comp_a = Component(
            self.source_a,
            monitor,
            allocator.lease(1)[0],
            DATASOURCE_FRAME,
            DATASOURCE_UNFRAME,
            self.source_a.get_hash()
        )
        comp_b = Component(
            self.source_b,
            monitor,
            allocator.lease(1)[0],
            DATASOURCE_FRAME,
            DATASOURCE_UNFRAME,
            self.source_b.get_hash()
        )

        # Date sort the sources, and run the sort in a dedicated
        # process
        sorted_out = date_sorted_sources(comp_a, comp_b)
        feed = Component(
            sorted_out,
            monitor,
            allocator.lease(1)[0],
            FEED_FRAME,
            FEED_UNFRAME,
            "sort"
        )

        passthrough = StatefulTransform(Passthrough)
        mavg_price = StatefulTransform(
            MovingAverage,
            ['price'],
            market_aware=False,
            delta=timedelta(minutes=20)
        )

        merged_gen = merged_transforms(feed, passthrough, mavg_price)
        merged = Component(
            merged_gen,
            monitor,
            allocator.lease(1)[0],
            MERGE_FRAME,
            MERGE_UNFRAME,
            "merge"
        )

        algo = TestAlgorithm(2, 10, 100, sid_filter=[2, 3])
        style = SIMULATION_STYLE.FIXED_SLIPPAGE
        trading_client = tsc(algo, self.environment, style)
        tsc_gen = trading_client.simulate(merged)
        tsc_comp = Component(
            tsc_gen,
            monitor,
            allocator.lease(1)[0],
            PERF_FRAME,
            BT_UPDATE_UNFRAME,
            "tsc"
        )

        mon_proc = launch_monitor(monitor)

        for message in tsc_comp:
            log.info(pf(message))

        # wait for processes to finish
        comp_a.proc.join()
        comp_b.proc.join()
        feed.proc.join()
        merged.proc.join()
        tsc_comp.proc.join()
        mon_proc.join()

    def test_single_thread(self):
        """
        Chain sort, transforms and trade simulation directly, without any
        Component or monitor, and drain the simulation output.
        """
        sorted_events = date_sorted_sources(self.source_a, self.source_b)

        passthrough = StatefulTransform(Passthrough)
        mavg_price = StatefulTransform(
            MovingAverage,
            ['price'],
            market_aware=False,
            delta=timedelta(minutes=20)
        )
        merged = merged_transforms(sorted_events, passthrough, mavg_price)

        algo = TestAlgorithm(2, 10, 100, sid_filter=[2, 3])
        style = SIMULATION_STYLE.FIXED_SLIPPAGE
        trading_client = tsc(algo, self.environment, style)

        for message in trading_client.simulate(merged):
            log.info(pf(message))

    def test_compound(self):
        """
        Wrap only the sorted feed and the merge in Components; the raw
        sources feed the sort directly and the trade simulation generator is
        consumed in the test process.
        """
        monitor = create_monitor(allocator)

        sorted_out = date_sorted_sources(self.source_a, self.source_b)
        feed = Component(
            sorted_out,
            monitor,
            allocator.lease(1)[0],
            FEED_FRAME,
            FEED_UNFRAME,
            "feed"
        )

        passthrough = StatefulTransform(Passthrough)
        mavg_price = StatefulTransform(
            MovingAverage,
            ['price'],
            market_aware=False,
            delta=timedelta(minutes=20)
        )

        merged_gen = merged_transforms(feed, passthrough, mavg_price)
        merged = Component(
            merged_gen,
            monitor,
            allocator.lease(1)[0],
            MERGE_FRAME,
            MERGE_UNFRAME,
            "merge"
        )

        algo = TestAlgorithm(2, 10, 100, sid_filter=[2, 3])
        style = SIMULATION_STYLE.FIXED_SLIPPAGE
        trading_client = tsc(algo, self.environment, style)
        tsc_gen = trading_client.simulate(merged)

        mon_proc = launch_monitor(monitor)

        for message in tsc_gen:
            log.info(pf(message))

        # wait for processes to finish
        feed.proc.join()
        merged.proc.join()
        mon_proc.join()
+57 -109
View File
@@ -7,30 +7,30 @@ from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm
from zipline.finance.trading import SIMULATION_STYLE
from zipline.core.devsimulator import AddressAllocator
from zipline.lines import SimulatedTrading
from zipline.gens.transform import StatefulTransform
from zipline.utils.test_utils import \
drain_zipline, \
check, \
setup_logger, \
teardown_logger
teardown_logger, \
ExceptionSource, \
ExceptionTransform
DEFAULT_TIMEOUT = 15 # seconds
EXTENDED_TIMEOUT = 90
allocator = AddressAllocator(1000)
class ExceptionTestCase(TestCase):
leased_sockets = defaultdict(list)
def setUp(self):
self.zipline_test_config = {
'allocator' : allocator,
'sid' : 133,
'devel' : False,
'results_socket' : allocator.lease(1)[0],
'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE
'sid' : 133,
'results_socket_uri' : allocator.lease(1)[0],
'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE
}
self.ctx = zmq.Context()
setup_logger(self)
@@ -39,6 +39,39 @@ class ExceptionTestCase(TestCase):
self.ctx.term()
teardown_logger(self)
def test_datasource_exception(self):
    """
    Run a test zipline whose trade source is an ExceptionSource and check
    that the only output is an EXCEPTION message describing a
    ZeroDivisionError.
    """
    self.zipline_test_config['trade_source'] = ExceptionSource()
    zipline = SimulatedTrading.create_test_zipline(
        **self.zipline_test_config
    )
    output, _ = drain_zipline(self, zipline)

    # TestCase assert methods (rather than bare asserts) match the other
    # tests in this class, report the offending values on failure and are
    # not stripped when Python runs with -O.
    self.assertEqual(len(output), 1)
    self.assertEqual(output[0]['prefix'], 'EXCEPTION')

    message = output[0]['payload']
    for field in ['date', 'message', 'name', 'stack']:
        self.assertIn(field, message.keys())
    self.assertEqual(
        message['message'],
        'integer division or modulo by zero'
    )
    self.assertEqual(message['name'], 'ZeroDivisionError')
def test_tranform_exception(self):
    """
    Run a test zipline whose only transform is an ExceptionTransform and
    check that the only output is an EXCEPTION message describing an
    AssertionError.
    """
    exc_tnfm = StatefulTransform(ExceptionTransform)
    self.zipline_test_config['transforms'] = [exc_tnfm]
    zipline = SimulatedTrading.create_test_zipline(
        **self.zipline_test_config
    )
    output, _ = drain_zipline(self, zipline)

    # TestCase assert methods (rather than bare asserts) match the other
    # tests in this class, report the offending values on failure and are
    # not stripped when Python runs with -O.
    self.assertEqual(len(output), 1)
    self.assertEqual(output[0]['prefix'], 'EXCEPTION')

    message = output[0]['payload']
    for field in ['date', 'message', 'name', 'stack']:
        self.assertIn(field, message.keys())
    self.assertEqual(message['message'], 'An assertion message')
    self.assertEqual(message['name'], 'AssertionError')
def test_exception_in_init(self):
# Simulation
# ----------
@@ -52,19 +85,17 @@ class ExceptionTestCase(TestCase):
**self.zipline_test_config
)
output, _ = drain_zipline(self, zipline)
self.assertEqual(len(output), 1)
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
payload = output[-1]['payload']
self.assertTrue(payload['date'])
del payload['date']
check(self, payload, INITIALIZE_TB)
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
self.assertEqual(payload['message'],'Algo exception in initialize')
self.assertEqual(payload['name'],'Exception')
# make sure our path shortening is working
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
def test_exception_in_handle_data(self):
# Simulation
# ----------
self.zipline_test_config['algorithm'] = \
@@ -78,16 +109,15 @@ class ExceptionTestCase(TestCase):
)
output, _ = drain_zipline(self, zipline)
self.assertEqual(len(output), 1)
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
payload = output[-1]['payload']
self.assertTrue(payload['date'])
del payload['date']
check(self, payload, HANDLE_DATA_TB)
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
self.assertEqual(payload['message'],'Algo exception in handle_data')
self.assertEqual(payload['name'],'Exception')
# make sure our path shortening is working
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
def test_zerodivision_exception_in_handle_data(self):
@@ -103,95 +133,13 @@ class ExceptionTestCase(TestCase):
)
output, _ = drain_zipline(self, zipline)
self.assertEqual(len(output), 5)
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
payload = output[-1]['payload']
self.assertTrue(payload['date'])
del payload['date']
check(self, payload, ZERO_DIV_TB)
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
# TODO:
# - define more zipline failure modes: exception in other
# components, exception in Monitor, etc. write tests
# for those scenarios.
INITIALIZE_TB =\
{'message': 'Algo exception in initialize',
'name': 'Exception',
'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'},
{'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'},
{'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.initialize_algo()',
'lineno': 97,
'method': 'do_work'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.do_op(self.algorithm.initialize)',
'lineno': 80,
'method': 'initialize_algo'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'callable_op(*args, **kwargs)',
'lineno': 210,
'method': 'do_op'},
{'filename': '/zipline/test_algorithms.py',
'line': 'raise Exception("Algo exception in initialize")',
'lineno': 166,
'method': 'initialize'}]}
HANDLE_DATA_TB =\
{
'message': 'Algo exception in handle_data',
'name': 'Exception',
'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'},
{'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'},
{'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.process_event(event)',
'lineno': 116,
'method': 'do_work'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.run_algorithm()',
'lineno': 164,
'method': 'process_event'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.do_op(self.algorithm.handle_data, data)',
'lineno': 186,
'method': 'run_algorithm'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'callable_op(*args, **kwargs)',
'lineno': 210,
'method': 'do_op'},
{'filename': '/zipline/test_algorithms.py',
'line': 'raise Exception("Algo exception in handle_data")',
'lineno': 187,
'method': 'handle_data'}]}
ZERO_DIV_TB= \
{'message': 'integer division or modulo by zero',
'name': 'ZeroDivisionError',
'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'},
{'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'},
{'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.process_event(event)',
'lineno': 116,
'method': 'do_work'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.run_algorithm()',
'lineno': 164,
'method': 'process_event'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'self.do_op(self.algorithm.handle_data, data)',
'lineno': 186,
'method': 'run_algorithm'},
{'filename': '/zipline/components/tradesimulation.py',
'line': 'callable_op(*args, **kwargs)',
'lineno': 210,
'method': 'do_op'},
{'filename': '/zipline/test_algorithms.py', 'line': '5/0', 'lineno': 218, 'method': 'handle_data'}]}
self.assertEqual(payload['message'],'integer division or modulo by zero')
self.assertEqual(payload['name'],'ZeroDivisionError')
# make sure our path shortening is working
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
-233
View File
@@ -1,233 +0,0 @@
from unittest2 import TestCase
from itertools import cycle, chain
from datetime import datetime, timedelta
from collections import deque
from zipline import ndict
from zipline.gens.sort import \
date_sort, \
ready, \
done, \
queue_is_ready,\
queue_is_done
from zipline.gens.utils import hash_args, alternate
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
from zipline.gens.composites import date_sorted_sources
import zipline.protocol as zp
class HelperTestCase(TestCase):
    """Checks for the queue-state predicates imported from zipline.gens.sort."""

    def setUp(self):
        """Nothing to prepare."""

    def tearDown(self):
        """Nothing to clean up."""

    def test_individual_queue_logic(self):
        """Walk a single queue through its empty, ready and done states."""
        pending = deque()

        # An empty queue is neither ready nor done.
        assert not queue_is_ready(pending)
        assert not queue_is_done(pending)

        # One ordinary message: ready, but not done.
        pending.append(to_dt('foo'))
        assert queue_is_ready(pending)
        assert not queue_is_done(pending)

        # DONE sitting in front of another message is an invalid state, and
        # queue_is_done is expected to trip an assertion on it.
        pending.appendleft(to_dt('DONE'))
        assert queue_is_ready(pending)
        with self.assertRaises(AssertionError):
            queue_is_done(pending)

        # Remove the message on the right so that only DONE remains.
        pending.pop()
        assert queue_is_ready(pending)
        assert queue_is_done(pending)

    def test_pop_logic(self):
        """Check ready/done over a dict holding several source queues."""
        names = ['a', 'b', 'c']
        sources = dict((name, deque()) for name in names)

        assert not ready(sources)
        assert not done(sources)

        # A message in only one queue is not enough: all sources must have
        # a message to be ready/done.
        sources['a'].append(to_dt("datetime"))
        assert not ready(sources)
        assert not done(sources)
        sources['a'].pop()

        for name in names:
            sources[name].append(to_dt("datetime"))
        assert ready(sources)
        assert not done(sources)

        # Every queue now holds [DONE, message], which trips the assert
        # inside the done check.
        for name in names:
            sources[name].appendleft(to_dt("DONE"))
        assert ready(sources)
        with self.assertRaises(AssertionError):
            done(sources)

        # Drop the message on the right of every queue, leaving only DONE.
        for name in names:
            sources[name].pop()
        assert ready(sources)
        assert done(sources)
class DateSortTestCase(TestCase):
    """Tests for date_sort and for the date_sorted_sources composite."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def run_date_sort(self, events, expected, source_ids):
        """
        Take a stream of events, their source_ids, and an expected sorting.
        Assert that date_sort's output agrees with expected.
        """
        sort_gen = date_sort(events, source_ids)
        assert list(sort_gen) == expected

    def test_single_source(self):
        """A single source should come out unchanged, minus the DONE event."""
        source_ids = ['a']
        event_type = zp.DATASOURCE_TYPE.TRADE

        # 100 dates from date_gen, followed by the "DONE" sentinel.
        dates = list(date_gen(count = 100))
        dates.append("DONE")

        # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
        event_args = zip(cycle(source_ids), iter(dates), cycle([event_type]))

        # Turn event_args into proper events.
        events = [mock_data_unframe(*args) for args in event_args]

        # The sort is not expected to yield the trailing DONE event.
        expected = events[:-1]

        event_gen = (e for e in events)
        self.run_date_sort(event_gen, expected, source_ids)

    def test_multi_source(self):
        """
        Two sources must be merged by dt (source_id breaking ties) both when
        their events arrive interleaved and when they arrive one source
        after the other.
        """
        source_ids = ['a', 'b']
        event_type = zp.DATASOURCE_TYPE.TRADE

        # Set up source 'a'. Outputs 20 events with 2 minute deltas.
        delta_a = timedelta(minutes = 2)
        dates_a = list(date_gen(delta = delta_a, count = 20))
        dates_a.append("DONE")
        events_a_args = zip(cycle(['a']), iter(dates_a), cycle([event_type]))
        events_a = [mock_data_unframe(*args) for args in events_a_args]

        # Set up source 'b'. Outputs 10 events with 1 minute deltas.
        delta_b = timedelta(minutes = 1)
        dates_b = list(date_gen(delta = delta_b, count = 10))
        dates_b.append("DONE")
        events_b_args = zip(cycle(['b']), iter(dates_b), cycle([event_type]))
        events_b = [mock_data_unframe(*args) for args in events_b_args]

        # The expected output is all non-DONE events in both a and b,
        # sorted first by dt and then by source_id.
        non_dones = events_a[:-1] + events_b[:-1]
        expected = sorted(non_dones, compare_by_dt_source_id)

        # Alternating between a and b.
        interleaved = alternate(iter(events_a), iter(events_b))
        self.run_date_sort(interleaved, expected, source_ids)

        # All of a, then all of b.
        sequential = chain(iter(events_a), iter(events_b))
        self.run_date_sort(sequential, expected, source_ids)

    def test_sorted_sources(self):
        """
        Pipe four SpecificEquityTrades sources through date_sorted_sources
        and check the ordering and the source ids of the output.
        """
        sid_filter = [1,2]

        # Set up source a. One hour between events.
        args_a = tuple()
        kwargs_a = {'sids' : [1,2,3,4],
                    'start' : datetime(2012,6,6,0),
                    'delta' : timedelta(hours = 1),
                    'filter' : sid_filter
                    }

        # Set up source b. One day between events.
        args_b = tuple()
        kwargs_b = {'sids' : [1,2,3,4],
                    'start' : datetime(2012,6,6,0),
                    'delta' : timedelta(days = 1),
                    'filter' : sid_filter
                    }

        # Set up source c. One minute between events.
        args_c = tuple()
        kwargs_c = {'sids' : [1,2,3,4],
                    'start' : datetime(2012,6,6,0),
                    'delta' : timedelta(minutes = 1),
                    'filter' : sid_filter
                    }

        # Set up source d. This should produce no events because the
        # internal sids don't match the filter.
        args_d = tuple()
        kwargs_d = {'sids' : [3,4],
                    'start' : datetime(2012,6,6,0),
                    'delta' : timedelta(minutes = 1),
                    'filter' : sid_filter
                    }

        sources = (SpecificEquityTrades,) * 4
        source_args = (args_a, args_b, args_c, args_d)
        source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)

        # Generate our expected source_ids.
        zip_args = zip(source_args, source_kwargs)
        expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
                        for args, kwargs in zip_args]

        # The id for source d must be built exactly like expected_ids above,
        # class-name prefix included. Comparing against the bare hash could
        # never match any event's source_id, so the check below would pass
        # no matter what source d emitted.
        source_d_id = "SpecificEquityTrades" + hash_args(*args_d, **kwargs_d)

        # Pipe our sources into sort.
        sort_out = date_sorted_sources(sources, source_args, source_kwargs)

        # Read all the values from sort and assert that they arrive in
        # the correct sorting with the expected hash values.
        to_list = list(sort_out)
        copy = to_list[:]

        for e in to_list:
            # All events should match one of our expected source_ids.
            assert e.source_id in expected_ids
            # But none of them should match source_d.
            assert e.source_id != source_d_id

        expected = sorted(copy, compare_by_dt_source_id)
        assert to_list == expected
def mock_data_unframe(source_id, dt, type):
    """Return a new ndict with source_id, dt and type set as attributes."""
    event = ndict()
    # Attributes are assigned in the same order as before: source_id, dt, type.
    for field, value in (('source_id', source_id), ('dt', dt), ('type', type)):
        setattr(event, field, value)
    return event
def to_dt(val):
    """Return an ndict whose only entry stores val under the 'dt' key."""
    wrapped = ndict({'dt': val})
    return wrapped
def compare_by_dt_source_id(x,y):
    """
    cmp-style comparator: order by dt, falling back to source_id when the
    dt values compare neither less nor greater.

    Returns -1 if x sorts before y, 1 if it sorts after, and 0 otherwise.
    """
    for field in ('dt', 'source_id'):
        left = getattr(x, field)
        right = getattr(y, field)
        if left < right:
            return -1
        if left > right:
            return 1
    return 0
+8 -17
View File
@@ -11,7 +11,6 @@ from collections import defaultdict
from nose.tools import timed
import zipline.utils.factory as factory
import zipline.protocol as zp
from zipline.test_algorithms import TestAlgorithm
from zipline.finance.trading import TradingEnvironment
@@ -19,10 +18,9 @@ from zipline.core.devsimulator import AddressAllocator
from zipline.lines import SimulatedTrading
from zipline.finance.performance import PerformanceTracker
from zipline.utils.protocol_utils import ndict
from zipline.finance.trading import TransactionSimulator, SIMULATION_STYLE
from zipline.finance.trading import TransactionSimulator
from zipline.utils.test_utils import \
drain_zipline, \
check, \
setup_logger, \
teardown_logger,\
assert_single_position
@@ -39,10 +37,8 @@ class FinanceTestCase(TestCase):
def setUp(self):
self.zipline_test_config = {
'allocator' : allocator,
'sid' : 133,
#'devel' : True,
'results_socket' : allocator.lease(1)[0]
'sid' : 133,
'results_socket_uri' : allocator.lease(1)[0]
}
self.ctx = zmq.Context()
@@ -60,7 +56,7 @@ class FinanceTestCase(TestCase):
trading_environment
)
prev = None
for trade in trade_source.event_list:
for trade in trade_source:
if prev:
self.assertTrue(trade.dt > prev.dt)
prev = trade
@@ -123,7 +119,6 @@ class FinanceTestCase(TestCase):
self.zipline_test_config['order_count'] = 100
self.zipline_test_config['trade_count'] = 200
zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config)
assert_single_position(self, zipline)
#@timed(DEFAULT_TIMEOUT)
@@ -148,9 +143,6 @@ class FinanceTestCase(TestCase):
)
output, transaction_count = drain_zipline(self, zipline)
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
#check that the algorithm received no events
self.assertEqual(
0,
@@ -301,12 +293,12 @@ class FinanceTestCase(TestCase):
# if present, expect transaction amounts to match orders exactly.
complete_fill = params.get('complete_fill')
sid = 1
trading_environment = factory.create_trading_environment()
trade_sim = TransactionSimulator()
trade_sim = TransactionSimulator([sid])
price = [10.1] * trade_count
volume = [100] * trade_count
start_date = trading_environment.first_open
sid = 1
generated_trades = factory.create_trade_history(
sid,
@@ -330,7 +322,7 @@ class FinanceTestCase(TestCase):
'dt' : order_date
})
trade_sim.add_open_order(order)
trade_sim.place_order(order)
order_date = order_date + order_interval
# move after market orders to just after market next
@@ -353,14 +345,13 @@ class FinanceTestCase(TestCase):
self.assertEqual(order.amount, order_amount * alternator**i)
tracker = PerformanceTracker(trading_environment)
tracker = PerformanceTracker(trading_environment, [sid])
# this approximates the loop inside TradingSimulationClient
transactions = []
for trade in generated_trades:
if trade_delay:
trade.dt = trade.dt + trade_delay
txn = trade_sim.apply_trade_to_open_orders(trade)
if txn:
transactions.append(txn)
+7 -5
View File
@@ -1,7 +1,7 @@
from zipline.utils.test_utils import setup_logger, teardown_logger
from unittest2 import TestCase, skip
from zipline.core.monitor import Controller
from zipline.core.monitor import Monitor
class TestMonitor(TestCase):
def setUp(self):
@@ -14,13 +14,15 @@ class TestMonitor(TestCase):
def test_init(self):
pub_socket = 'tcp://127.0.0.1:5000'
route_socket = 'tcp://127.0.0.1:5001'
exception_socket = 'tcp://127.0.0.1:5002'
con = Controller(pub_socket, route_socket)
con.manage([])
mon = Monitor(pub_socket, route_socket, exception_socket)
mon.manage([])
def test_init_topology(self):
pub_socket = 'tcp://127.0.0.1:5000'
route_socket = 'tcp://127.0.0.1:5001'
exception_socket = 'tcp://127.0.0.1:5002'
con = Controller(pub_socket, route_socket, )
con.manage([ 'a', 'b', 'c', 'd' ])
mon = Monitor(pub_socket, route_socket, exception_socket)
mon.manage([ 'a', 'b', 'c', 'd' ])
+4 -1
View File
@@ -543,7 +543,10 @@ shares in position"
self.trading_environment.capital_base = 1000.0
self.trading_environment.frame_index = ['sid', 'volume', 'dt', \
'price', 'changed']
perf_tracker = perf.PerformanceTracker(self.trading_environment)
perf_tracker = perf.PerformanceTracker(
self.trading_environment,
[sid, sid2]
)
for event in trade_history:
#create a transaction for all but
-4
View File
@@ -1,8 +1,6 @@
"""
Test the FRAME/UNFRAME functions in the sequence expected from ziplines.
"""
import pytz
from unittest2 import TestCase
from datetime import datetime, timedelta
from collections import defaultdict
@@ -10,10 +8,8 @@ from collections import defaultdict
from nose.tools import timed
import zipline.utils.factory as factory
from zipline.utils import logger
import zipline.protocol as zp
from zipline.finance.sources import SpecificEquityTrades
DEFAULT_TIMEOUT = 5 # seconds
+259
View File
@@ -0,0 +1,259 @@
import pytz
from unittest2 import TestCase
from itertools import cycle, chain, izip, izip_longest
from datetime import datetime, timedelta
from collections import deque
from zipline import ndict
from zipline.gens.sort import \
date_sort, \
ready, \
done, \
queue_is_ready,\
queue_is_done
from zipline.gens.utils import hash_args, alternate, done_message
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
from zipline.gens.composites import date_sorted_sources
import zipline.protocol as zp
class HelperTestCase(TestCase):
    """Checks for the queue-state predicates imported from zipline.gens.sort."""

    def setUp(self):
        """Nothing to prepare."""

    def tearDown(self):
        """Nothing to clean up."""

    def test_individual_queue_logic(self):
        """Walk a single queue through its empty, ready and done states."""
        pending = deque()

        # An empty queue is neither ready nor done.
        assert not queue_is_ready(pending)
        assert not queue_is_done(pending)

        # One ordinary message: ready, but not done.
        pending.append(to_dt('foo'))
        assert queue_is_ready(pending)
        assert not queue_is_done(pending)

        # DONE sitting in front of another message is an invalid state, and
        # queue_is_done is expected to trip an assertion on it.
        pending.appendleft(to_dt('DONE'))
        assert queue_is_ready(pending)
        with self.assertRaises(AssertionError):
            queue_is_done(pending)

        # Remove the message on the right so that only DONE remains.
        pending.pop()
        assert queue_is_ready(pending)
        assert queue_is_done(pending)

    def test_pop_logic(self):
        """Check ready/done over a dict holding several source queues."""
        names = ['a', 'b', 'c']
        sources = dict((name, deque()) for name in names)

        assert not ready(sources)
        assert not done(sources)

        # A message in only one queue is not enough: all sources must have
        # a message to be ready/done.
        sources['a'].append(to_dt("datetime"))
        assert not ready(sources)
        assert not done(sources)
        sources['a'].pop()

        for name in names:
            sources[name].append(to_dt("datetime"))
        assert ready(sources)
        assert not done(sources)

        # Every queue now holds [DONE, message], which trips the assert
        # inside the done check.
        for name in names:
            sources[name].appendleft(to_dt("DONE"))
        assert ready(sources)
        with self.assertRaises(AssertionError):
            done(sources)

        # Drop the message on the right of every queue, leaving only DONE.
        for name in names:
            sources[name].pop()
        assert ready(sources)
        assert done(sources)
class DateSortTestCase(TestCase):
    """Tests for date_sort and for the date_sorted_sources composite."""

    def setUp(self):
        """Nothing to prepare."""

    def tearDown(self):
        """Nothing to clean up."""

    def _with_done(self, source):
        """Return source followed by a done message carrying its hash."""
        return chain(source, [done_message(source.get_hash())])

    def run_date_sort(self, event_stream, expected, source_ids):
        """
        Feed event_stream and source_ids to date_sort and assert that the
        output matches expected element by element. izip_longest pads the
        shorter side with None, so a length mismatch also fails.
        """
        sort_out = date_sort(event_stream, source_ids)
        for actual, wanted in izip_longest(sort_out, expected):
            assert actual == wanted

    def test_single_source(self):
        """A single source should pass through date_sort unchanged."""
        # Just using the built-in defaults. See
        # zipline/gens/tradegens.py
        source = SpecificEquityTrades()
        expected = list(source)
        source.rewind()

        # The raw source doesn't handle done messaging, so a done message
        # is appended for sort to work properly.
        with_done = self._with_done(source)
        self.run_date_sort(with_done, expected, [source.get_hash()])

    def test_multi_source(self):
        """
        Two sources must be merged by dt (source_id breaking ties) both when
        their events arrive interleaved and when they arrive one source
        after the other.
        """
        sid_filter = [2,3]
        start = datetime(2012,1,3,15, tzinfo = pytz.utc)

        source_a = SpecificEquityTrades(
            count = 100,
            sids = [1,2,3],
            start = start,
            delta = timedelta(minutes = 6),
            filter = sid_filter
        )
        source_b = SpecificEquityTrades(
            count = 100,
            sids = [2,3,4],
            start = datetime(2012,1,3,15, tzinfo = pytz.utc),
            delta = timedelta(minutes = 5),
            filter = sid_filter
        )

        # The expected output is all events, sorted by dt with
        # source_id as a tiebreaker.
        all_events = list(chain(source_a, source_b))
        expected = sorted(all_events, comp)
        source_ids = [source_a.get_hash(), source_b.get_hash()]

        # Building all_events consumed both sources, so rewind them
        # before every run.
        source_a.rewind()
        source_b.rewind()
        with_done_a = self._with_done(source_a)
        with_done_b = self._with_done(source_b)

        # First run: messages alternate between source_a and source_b.
        interleaved = alternate(with_done_a, with_done_b)
        self.run_date_sort(interleaved, expected, source_ids)

        source_a.rewind()
        source_b.rewind()
        with_done_a = self._with_done(source_a)
        with_done_b = self._with_done(source_b)

        # Second run: all messages from a, followed by all messages from b.
        sequential = chain(with_done_a, with_done_b)
        self.run_date_sort(sequential, expected, source_ids)

    def test_sort_composite(self):
        """
        Run four sources through date_sorted_sources and check the total
        event count, the source ids and the ordering of the output.
        """
        sid_filter = [1,2]

        # Source a: count 100, one hour between events.
        source_a = SpecificEquityTrades(
            count = 100,
            sids = [1],
            start = datetime(2012,6,6,0),
            delta = timedelta(hours = 1),
            filter = sid_filter
        )

        # Source b: count 50, one day between events.
        source_b = SpecificEquityTrades(
            count = 50,
            sids = [2],
            start = datetime(2012,6,6,0),
            delta = timedelta(days = 1),
            filter = sid_filter
        )

        # Source c: count 150, one minute between events.
        source_c = SpecificEquityTrades(
            count = 150,
            sids = [1,2],
            start = datetime(2012,6,6,0),
            delta = timedelta(minutes = 1),
            filter = sid_filter
        )

        # Source d: should produce no events because its only sid (3) is
        # not in the filter.
        source_d = SpecificEquityTrades(
            count = 50,
            sids = [3],
            start = datetime(2012,6,6,0),
            delta = timedelta(minutes = 1),
            filter = sid_filter
        )

        sources = [source_a, source_b, source_c, source_d]
        hashes = [source.get_hash() for source in sources]

        sort_out = date_sorted_sources(*sources)

        # Read all the values from sort and assert that they arrive in
        # the correct sorting with the expected hash values.
        to_list = list(sort_out)
        copy = to_list[:]

        # We should have 300 events (100 from a, 50 from b, 150 from c).
        assert len(to_list) == 300

        for e in to_list:
            # All events should match one of our expected source_ids.
            assert e.source_id in hashes
            # But none of them should match source_d.
            assert e.source_id != source_d.get_hash()

        # The events should be sorted by dt, with source_id as tiebreaker.
        expected = sorted(copy, comp)
        assert to_list == expected
def compare_by_dt_source_id(x,y):
    """
    cmp-style comparator: order by dt, falling back to source_id when the
    dt values compare neither less nor greater.

    Returns -1 if x sorts before y, 1 if it sorts after, and 0 otherwise.
    """
    for field in ('dt', 'source_id'):
        left = getattr(x, field)
        right = getattr(y, field)
        if left < right:
            return -1
        if left > right:
            return 1
    return 0
#Alias for ease of use
comp = compare_by_dt_source_id
def to_dt(msg):
    """Return an ndict whose only entry stores msg under the 'dt' key."""
    wrapped = ndict({'dt': msg})
    return wrapped
+221 -57
View File
@@ -1,102 +1,266 @@
from datetime import timedelta
import pytz
from datetime import timedelta, datetime
from collections import defaultdict
from unittest2 import TestCase
from zipline import ndict
from zipline.lines import SimulatedTrading
from zipline.utils.test_utils import setup_logger, teardown_logger
from zipline.utils.date_utils import utcnow
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.transform import StatefulTransform, EventWindow
from zipline.gens.vwap import VWAP
from zipline.gens.mavg import MovingAverage
from zipline.gens.returns import Returns
import zipline.utils.factory as factory
from zipline.finance.vwap import DailyVWAP, VWAPTransform
from zipline.finance.returns import ReturnsFromPriorClose
from zipline.finance.movingaverage import MovingAverage
from zipline.lines import SimulatedTrading
from zipline.core.devsimulator import AddressAllocator
allocator = AddressAllocator(1000)
def to_dt(msg):
    """
    Wrap msg in an ndict under the single key 'dt'.
    """
    payload = {'dt': msg}
    return ndict(payload)
class ZiplineWithTransformsTestCase(TestCase):
leased_sockets = defaultdict(list)
class NoopEventWindow(EventWindow):
    """
    Minimal EventWindow subclass for exercising the base EventWindow
    logic in tests.

    Every event passed to handle_add or handle_remove is recorded, in
    call order, on the `added` and `removed` lists.
    """

    def __init__(self, market_aware, days, delta):
        EventWindow.__init__(self, market_aware, days, delta)
        # Histories of events passed to handle_add / handle_remove.
        self.added, self.removed = [], []

    def handle_add(self, event):
        # Record the event; no other processing is done.
        self.added.append(event)

    def handle_remove(self, event):
        # Record the event; no other processing is done.
        self.removed.append(event)
class EventWindowTestCase(TestCase):
def setUp(self):
# skip ahead 100 spots
allocator.lease(100)
self.trading_environment = factory.create_trading_environment()
self.zipline_test_config = {
'allocator' : allocator,
'sid' : 133,
'devel' : True
}
setup_logger(self, '/var/log/qexec/qexed.log')
setup_logger(self)
# Constants calling before open, during the day, and after
# close on a valid trading day.
self.pre_open = datetime(2012, 8, 7, 13, tzinfo = pytz.utc)
self.mid_day = datetime(2012, 8, 7, 15, tzinfo = pytz.utc)
self.post_close = datetime(2012, 8, 7, 22, tzinfo = pytz.utc)
def tearDown(self):
teardown_logger(self)
# Constants calling before open, during the day, and after
# close on a saturday.
self.pre_open_saturday = datetime(2012, 8, 11, 13, tzinfo = pytz.utc)
self.mid_day_saturday = datetime(2012, 8, 11, 15, tzinfo = pytz.utc)
self.post_close_saturday = datetime(2012, 8, 11, 22, tzinfo = pytz.utc)
def test_vwap_tnfm(self):
zipline = SimulatedTrading.create_test_zipline(
**self.zipline_test_config
# Timestamps falling before open, during the day, and after
# close on a holiday (December 25, 2012).
self.pre_open_holiday = datetime(2012, 12, 25, 13, tzinfo = pytz.utc)
# The hour must be given explicitly: datetime() defaults it to 0,
# which would make this midnight rather than mid-day. 15 matches
# the hour used for mid_day and mid_day_saturday.
self.mid_day_holiday = datetime(2012, 12, 25, 15, tzinfo = pytz.utc)
self.post_close_holiday = datetime(2012, 12, 25, 22, tzinfo = pytz.utc)
def test_event_window_with_timedelta(self):
# Keep all events within a 5 minute window.
window = NoopEventWindow(
market_aware = False,
delta = timedelta(minutes = 5),
days = None
)
vwap = VWAPTransform("vwap_10", daycount=10)
zipline.add_transform(vwap)
now = utcnow()
zipline.simulate(blocking=True)
# 15 dates, increasing in 1 minute increments.
dates = [now + i * timedelta(minutes = 1)
for i in xrange(15)]
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
# Turn the dates into the format required by EventWindow.
dt_messages = [to_dt(date) for date in dates]
# Run all messages through the window and assert that we're adding
# and removing messages appropriately. We start the enumeration at 1
# for convenience.
for num, message in enumerate(dt_messages, 1):
window.update(message)
# Assert that we've added the correct number of events.
assert len(window.added) == num
# Assert that we removed only events that fall outside (or
# on the boundary of) the delta.
for dropped in window.removed:
assert message.dt - dropped.dt >= timedelta(minutes = 5)
def test_market_aware_window(self):
    """
    With a one-day market-aware window, an event arriving one day and
    one second after pre_open should cause the three pre_open events
    to be removed while everything later is kept.
    """
    window = NoopEventWindow(
        market_aware = True,
        delta = None,
        days = 1
    )

    # Three events at each of pre-open, mid-day and post-close...
    dates = []
    for moment in (self.pre_open, self.mid_day, self.post_close):
        dates.extend([moment] * 3)
    # ...followed by one event just over a day after pre-open.
    dates.append(self.pre_open + timedelta(days = 1, seconds = 1))

    events = [to_dt(date) for date in dates]

    # Feed every event through the window.
    for event in events:
        window.update(event)

    # Only the first day's pre_open events should have been removed;
    # the rest should still be in the window.
    first_day_pre_open = events[0:3]
    remaining = events[3:]
    assert window.added == events
    assert window.removed == first_day_pre_open
    assert list(window.ticks) == remaining
def test_market_aware_window_weekend(self):
    """
    With a two-day market-aware window, events from just before the
    Saturday constants should survive an event one day after
    mid_day_saturday; only the earliest is removed once an event two
    days after mid_day_saturday arrives.
    """
    window = NoopEventWindow(
        market_aware = True,
        delta = None,
        days = 2
    )

    # Shift the Saturday constants back by a day and a second.
    shift = timedelta(days = 1, seconds=1)
    dates = [
        self.pre_open_saturday - shift,
        self.mid_day_saturday - shift,
        self.post_close_saturday - shift,
        self.mid_day_saturday + timedelta(days = 1),
    ]
    events = [to_dt(date) for date in dates]

    # Feed every event through the window.
    for event in events:
        window.update(event)

    # Nothing should have been removed yet.
    assert window.added == events
    assert window.removed == []
    assert list(window.ticks) == events

    # One more event, two days after mid_day_saturday.
    extra = to_dt(self.mid_day_saturday + timedelta(days = 2))
    window.update(extra)

    # Only the very first event should have been removed.
    oldest = events[0]
    assert window.removed == [oldest]
    assert list(window.ticks) == events[1:] + [extra]
def tearDown(self):
    """
    Undo the logging setup performed by setUp.
    """
    # setUp calls setup_logger(self). The teardown hook must call the
    # matching teardown_logger rather than setting the logger up a
    # second time, otherwise the test's log handler is never released.
    # NOTE(review): assumes teardown_logger (imported from
    # zipline.utils.test_utils) is the inverse of setup_logger - confirm.
    teardown_logger(self)
class FinanceTransformsTestCase(TestCase):
def setUp(self):
self.trading_environment = factory.create_trading_environment()
setup_logger(self, '/var/log/qexec/qexec.log')
setup_logger(self)
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.trading_environment
)
self.source = SpecificEquityTrades(event_list=trade_history)
def tearDown(self):
self.log_handler.pop_application()
def test_vwap(self):
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 10.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.trading_environment
vwap = StatefulTransform(
VWAP,
market_aware = False,
delta = timedelta(days = 2)
)
transformed = list(vwap.transform(self.source))
vwap = DailyVWAP(days=2)
for trade in trade_history:
vwap.update(trade)
self.assertEqual(vwap.vwap, 10.75)
# Output values
tnfm_vals = [message.tnfm_value for message in transformed]
# "Hand calculated" values.
expected = [
(10.0 * 100) / 100.0,
((10.0 * 100) + (10.0 * 100)) / (200.0),
# We should drop the first event here.
((10.0 * 100) + (11.0 * 100)) / (200.0),
# We should drop the second event here.
((11.0 * 100) + (11.0 * 300)) / (400.0)
]
# Output should match the expected.
assert tnfm_vals == expected
def test_returns(self):
# Daily returns.
returns = StatefulTransform(Returns, 1)
transformed = list(returns.transform(self.source))
tnfm_vals = [message.tnfm_value for message in transformed]
# No returns for the first event because we don't have a
# previous close.
expected = [0.0, 0.0, 0.1, 0.0]
assert tnfm_vals == expected
# Two-day returns. An extra kink here is that the
# factory will automatically skip a weekend for the
# last event. Results shouldn't notice this blip.
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 10.0, 11.0],
[100, 100, 100, 300],
[10.0, 15.0, 13.0, 12.0, 13.0],
[100, 100, 100, 300, 100],
timedelta(days=1),
self.trading_environment
)
self.source = SpecificEquityTrades(event_list=trade_history)
returns = ReturnsFromPriorClose()
for trade in trade_history:
returns.update(trade)
returns = StatefulTransform(Returns, 2)
transformed = list(returns.transform(self.source))
tnfm_vals = [message.tnfm_value for message in transformed]
self.assertEqual(returns.returns, .1)
expected = [
0.0,
0.0,
(13.0 - 10.0) / 10.0,
(12.0 - 15.0) / 15.0,
(13.0 - 13.0) / 13.0
]
assert tnfm_vals == expected
def test_moving_average(self):
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 10.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.trading_environment
mavg = StatefulTransform(
MovingAverage,
market_aware = False,
fields = ['price', 'volume'],
delta = timedelta(days = 2),
)
transformed = list(mavg.transform(self.source))
# Output values.
tnfm_prices = [message.tnfm_value.price for message in transformed]
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
ma = MovingAverage(days=2)
for trade in trade_history:
ma.update(trade)
# "Hand-calculated" values
expected_prices = [
((10.0) / 1.0),
((10.0 + 10.0) / 2.0),
# First event should get dropped here.
((10.0 + 11.0) / 2.0),
# Second event should get dropped here.
((11.0 + 11.0) / 2.0)
]
expected_volumes = [
((100.0) / 1.0),
((100.0 + 100.0) / 2.0),
# First event should get dropped here.
((100.0 + 100.0) / 2.0),
# Second event should get dropped here.
((100.0 + 300.0) / 2.0)
]
self.assertEqual(ma.average, 10.5)
assert tnfm_prices == expected_prices
assert tnfm_volumes == expected_volumes
-6
View File
@@ -6,15 +6,9 @@ Zipline
# it is a place to expose the public interfaces.
import protocol # namespace
from core.monitor import Controller
from lines import SimulatedTrading
from core.host import ComponentHost
from utils.protocol_utils import ndict
__all__ = [
SimulatedTrading,
Controller,
ComponentHost,
protocol,
ndict
]
+12 -28
View File
@@ -12,7 +12,7 @@ from zipline.utils.protocol_utils import ndict
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
from logbook import Logger, NestedSetup, Processor, queues
from logbook import Logger, NestedSetup, Processor
log = logbook.Logger('TradeSimulation')
@@ -20,7 +20,7 @@ log = logbook.Logger('TradeSimulation')
class TradeSimulationClient(Component):
def init(self, trading_environment, sim_style, results_socket):
def init(self, trading_environment, sim_style, results_socket, algorithm):
self.received_count = 0
self.prev_dt = None
self.event_queue = None
@@ -29,13 +29,20 @@ class TradeSimulationClient(Component):
self.trading_environment = trading_environment
self.current_dt = trading_environment.period_start
self.last_iteration_dur = datetime.timedelta(seconds=0)
self.algorithm = None
self.algorithm = algorithm
self.algorithm.set_order(self.order)
self.max_wait = datetime.timedelta(seconds=60)
self.last_msg_dt = datetime.datetime.utcnow()
self.txn_sim = TransactionSimulator(sim_style)
self.txn_sim = TransactionSimulator(
open_orders={},
style=sim_style
)
self.event_data = ndict()
self.perf = perf.PerformanceTracker(self.trading_environment)
self.perf = perf.PerformanceTracker(
self.trading_environment,
self.algorithm.get_sid_filter()
)
self.zmq_out = None
self.results_socket = results_socket
self.algo_initialized = False
@@ -44,19 +51,6 @@ class TradeSimulationClient(Component):
def get_id(self):
return str(zp.FINANCE_COMPONENT.TRADING_CLIENT)
def set_algorithm(self, algorithm):
"""
:param algorithm: must implement the algorithm protocol. See
:py:mod:`zipline.test.algorithm`
"""
self.algorithm = algorithm
# register the client's order method with the algorithm
self.algorithm.set_order(self.order)
# we need to provide the performance tracker with the
# sids referenced in the algorithm, so portfolio can
# initialize with all possible sids.
self.perf.set_sids(self.algorithm.get_sid_filter())
def open(self):
self.result_feed = self.connect_result()
if self.results_socket:
@@ -185,16 +179,6 @@ class TradeSimulationClient(Component):
# LOG_EXTRA_FIELDS in zipline/protocol.py
self.do_op(self.algorithm.handle_data, data)
def exception_callback(self, exc_type, exc_value, exc_traceback):
if self.results_socket:
log.info("Sending exception frame")
msg = zp.EXCEPTION_FRAME(
exc_traceback,
exc_type.__name__,
exc_value.message
)
self.out_socket.send(msg)
def do_op(self, callable_op, *args, **kwargs):
""" Wrap a callable operation with the zmq logbook
handler if it exits."""
+2 -2
View File
@@ -1,9 +1,9 @@
from host import ComponentHost
from component import Component
from monitor import Controller
from monitor import Monitor
__all__ = [
Component,
Controller,
Monitor,
ComponentHost
]
+361 -495
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -103,7 +103,7 @@ class ComponentHost(object):
log.info('== Roll Call ==')
log.info('Controller')
log.info('Monitor')
self.launch_controller()
+79 -55
View File
@@ -1,4 +1,3 @@
import inspect
import os
import zmq
import sys
@@ -10,8 +9,13 @@ from signal import SIGHUP, SIGINT
from collections import OrderedDict, Counter
from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \
from zipline.protocol import (
CONTROL_PROTOCOL,
CONTROL_FRAME,
CONTROL_UNFRAME,
CONTROL_STATES,
INVALID_CONTROL_FRAME
)
from zipline.utils.protocol_utils import ndict
@@ -34,7 +38,7 @@ class UnknownChatter(Exception):
return """Component calling itself "%s" talking on unexpected channel""" % self.named
log = logbook.Logger('Controller')
log = logbook.Logger('Monitor')
# The scalars determining the timing of the monitor behavior for
# the system.
@@ -52,7 +56,7 @@ PARAMETERS = ndict(dict(
SYSTEM_TIMEOUT = 50,
))
class Controller(object):
class Monitor(object):
"""
A N to M messaging system for inter component communication.
@@ -69,7 +73,12 @@ class Controller(object):
debug = True
period = PARAMETERS.GENERATIONAL_PERIOD
def __init__(self, pub_socket, route_socket, send_sighup=False):
def __init__(
self,
pub_socket,
route_socket,
exception_socket,
send_sighup=False):
self.nosignals = False
self.context = None
@@ -90,13 +99,15 @@ class Controller(object):
self.associated = []
self.pub_socket = pub_socket
self.route_socket = route_socket
self.error_replay = OrderedDict()
self.pub_socket = pub_socket
self.route_socket = route_socket
self.exception_socket = exception_socket
self.missed_beats = Counter()
# start with an empty topology
self.topology = set([])
self.send_sighup = send_sighup
if self.send_sighup:
log.info("Request to send sighup/sigint")
@@ -108,6 +119,17 @@ class Controller(object):
self.zmq_poller = self.zmq.Poller
return
def add_to_topology(self, component_id):
    """
    Register a component in the topology.

    Adds both component_id and its "FORK-"-prefixed counterpart to
    self.topology.
    """
    fork_id = "FORK-" + component_id
    self.topology.update((component_id, fork_id))
def freeze_topology(self):
    """
    Hand the incrementally-built topology to manage(), unless
    self.topology is already a frozenset.
    """
    already_frozen = isinstance(self.topology, frozenset)
    if not already_frozen:
        # Components were added one at a time; pass the accumulated
        # topology on to manage().
        self.manage(self.topology)
def manage(self, topology):
"""
Give the controller a set set of components to manage and
@@ -139,6 +161,7 @@ class Controller(object):
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
def run(self):
self.freeze_topology()
self.running = True
self.init_zmq()
setproctitle('Monitor')
@@ -243,6 +266,11 @@ class Controller(object):
self.router.bind(self.route_socket)
self.router.setsockopt(zmq.LINGER, 0)
# -- Exception Out --
# ===================
self.ex_out = self.context.socket(self.zmq.PUSH)
self.ex_out.connect(self.exception_socket)
poller = self.zmq.Poller()
poller.register(self.router, self.zmq.POLLIN)
#poller.register(self.cancel, self.zmq.POLLIN)
@@ -312,6 +340,11 @@ class Controller(object):
log.info("breaking out of initial heartbeat")
break
# Break out if the entire topology told us its DONE
if len(self.finished) == len(self.topology):
break
# ================
# Heartbeat Stats
# ================
@@ -374,6 +407,7 @@ class Controller(object):
"""
if not self.send_sighup:
log.warning("Skipping SIGINT")
return
ppid = os.getpid()
log.warning("Sending SIGINT")
os.kill(ppid, SIGINT)
@@ -403,28 +437,22 @@ class Controller(object):
bad = self.tracked - good - self.finished
new = self.responses - good - self.finished
missing = self.topology - self.tracked - self.finished
for component in new:
self.new(component)
if self.debug:
log.info('New component %r' % component)
for component in bad:
self.fail(component)
self.timed_out(component)
missing = self.topology - self.tracked - self.finished
for component in missing:
if self.debug:
log.info('Missing component %r' % component)
if self.debug:
for component in self.tracked:
if component not in self.topology:
log.info('Uninvited component %r' % component)
for component in self.tracked:
if component not in self.topology:
log.info('Uninvited component %r' % component)
# --------------
# Init Handlers
@@ -460,22 +488,16 @@ class Controller(object):
# Epic Fail Handling
# ------------------
def fail_universal(self):
# TODO: this requires higher order functionality
log.error('Component missed heartbeat, Monitor shutting down system')
self.kill()
def fail(self, component):
def timed_out(self, component):
if self.state is CONTROL_STATES.TERMINATE:
return
universal = self.fail_universal
fail_handlers = { }
if component in (self.topology - self.finished) or self.freeform:
log.warning('Component "%s" missed heartbeat' % component)
self.tracked.remove(component)
fail_handlers.get(component, universal)()
# we treat a time out as a severe failure, and
# conduct a rapid shutdown
self.kill()
# -------------------
# Completion Handling
@@ -489,26 +511,15 @@ class Controller(object):
# --------------
# Error Handling
# --------------
def exception_universal(self):
"""
Shutdown the system on failure.
"""
log.error('System in exception state, shutting down')
def exception(self, component, exception_data):
    """
    Handle a component reporting that it is in an exception state.

    Logs the failure, sends exception_data out on self.ex_out, and
    then calls self.kill().
    """
    message = 'Component in exception state: %s. Shutting down system and sending exception data to listeners.' % component
    log.error(message)

    # Forward the exception message to whoever is listening.
    self.ex_out.send(exception_data)

    # One failed component is treated as a hard failure of the whole
    # system, so shut down immediately.
    self.kill()
def exception(self, component, failure):
universal = self.exception_universal
exception_handlers = { }
if component in self.topology or self.freeform:
self.error_replay[(component, time.time())] = failure
log.error('Component in exception state: %s' % component)
exception_handlers.get(component, universal)()
else:
raise UnknownChatter(component)
# -----------------
# Protocol Handling
# -----------------
@@ -555,7 +566,19 @@ class Controller(object):
# A component is telling us it failed, and how
if id is CONTROL_PROTOCOL.EXCEPTION:
self.exception(identity, status)
# status should be a msgpack emitted from
# EXCEPTION_FRAME
try:
exception_data = status
self.exception(identity, exception_data)
except:
# if an exception occurs when we try to handle
# the exception, signal the parent that we need
# to go down
# TODO: should we attempt to call self.exception?
log.exception("Unexpected exception sending exception data")
self.kill()
return
# A component is telling us its done with work and won't
@@ -605,11 +628,9 @@ class Controller(object):
self.associated.append(s)
return s
def do_error_replay(self):
for (component, time), error in self.error_replay.iteritems():
log.info('Component Log for -- %s --:\n%s' % (component, error))
def kill(self):
"""Aggressively exit the whole zipline.
"""
if self.state is CONTROL_STATES.TERMINATE:
return
@@ -617,6 +638,9 @@ class Controller(object):
self.send_hardkill()
self.state = CONTROL_STATES.TERMINATE
self.alive = False
# send burrito an interrupt, instructing it to kill all
# child processes associated with this zipline.
time.sleep(3)
self.signal_interrupt()
def shutdown(self):
+4 -4
View File
@@ -40,13 +40,13 @@ class ProcessSimulator(ComponentHost):
# invoked by the host's open()
def launch_controller(self):
proc = multiprocessing.Process(target=self.controller.run)
proc = multiprocessing.Process(target=self.monitor.run)
proc.start()
self.con = proc
# Process specific
self.controller_process = proc
self.mapping[proc.pid] = 'Controller'
self.monitor_process = proc
self.mapping[proc.pid] = 'Monitor'
def launch_component(self, component):
proc = multiprocessing.Process(target=component.run)
@@ -81,7 +81,7 @@ class ProcessSimulator(ComponentHost):
process.join(timeout=1)
process.terminate()
self.controller.shutdown(soft=True)
self.monitor.shutdown(soft=True)
self.running = False
self.con.terminate()
+32 -23
View File
@@ -133,6 +133,7 @@ import zipline.finance.risk as risk
log = logbook.Logger('Performance')
class PerformanceTracker(object):
UPDATER = True
"""
Tracks the performance of the zipline as it is running in
the simulator, relays this out to the Deluge broker and then
@@ -144,7 +145,7 @@ class PerformanceTracker(object):
"""
def __init__(self, trading_environment):
def __init__(self, trading_environment, sid_list):
self.trading_environment = trading_environment
self.trading_day = datetime.timedelta(hours = 6, minutes = 30)
@@ -164,7 +165,6 @@ class PerformanceTracker(object):
self.txn_count = 0
self.event_count = 0
self.last_dict = None
self.order_log = []
self.exceeded_max_loss = False
self.results_socket = None
@@ -197,10 +197,21 @@ class PerformanceTracker(object):
# save the transactions for the daily periods
keep_transactions = True
)
def set_sids(self, sid_list):
for sid in sid_list:
self.cumulative_performance.positions[sid] = Position(sid)
self.todays_performance.positions[sid] = Position(sid)
def update(self, event):
    """
    Attach performance information to event and return it.

    When event.dt is the string "DONE", perf_message is set to the
    result of handle_simulation_end(). Otherwise perf_message is set
    to the result of process_event(event) and portfolio to
    get_portfolio(). In both cases the 'TRANSACTION' entry is deleted
    from the event before it is returned.
    """
    if event.dt == "DONE":
        event.perf_message = self.handle_simulation_end()
    else:
        event.perf_message = self.process_event(event)
        event.portfolio = self.get_portfolio()

    # The transaction is dropped from the event on both paths.
    del event['TRANSACTION']
    return event
def get_portfolio(self):
return self.cumulative_performance.as_portfolio()
@@ -238,10 +249,10 @@ class PerformanceTracker(object):
'cumulative_risk_metrics' : self.cumulative_risk_metrics.to_dict()
}
def log_order(self, order):
self.order_log.append(order)
def process_event(self, event):
message = None
if self.exceeded_max_loss:
return
@@ -250,7 +261,7 @@ class PerformanceTracker(object):
self.event_count += 1
if(event.dt >= self.market_close):
self.handle_market_close()
message = self.handle_market_close()
if event.TRANSACTION:
self.txn_count += 1
@@ -264,9 +275,12 @@ class PerformanceTracker(object):
#calculate performance as of last trade
self.cumulative_performance.calculate_performance()
self.todays_performance.calculate_performance()
return message
def handle_market_close(self):
# add the return results from today to the list of DailyReturn objects.
todays_date = self.market_close.replace(hour=0, minute=0, second=0)
todays_return_obj = risk.DailyReturn(
@@ -288,12 +302,10 @@ class PerformanceTracker(object):
# calculate progress of test
self.progress = self.day_count / self.total_days
# Output results
if self.results_socket:
msg = zp.PERF_FRAME(self.to_dict())
self.results_socket.send(msg)
# Take a snapshot of our current performance to return to the
# browser.
daily_update = self.to_dict()
#
if self.trading_environment.max_drawdown:
returns = self.todays_performance.returns
max_dd = -1 * self.trading_environment.max_drawdown
@@ -304,7 +316,7 @@ class PerformanceTracker(object):
# so it shows up in the update, but don't end the test
# here. Let the update go out before stopping
self.exceeded_max_loss = True
return
return daily_update
#move the market day markers forward
@@ -326,6 +338,8 @@ class PerformanceTracker(object):
self.market_close,
keep_transactions = True
)
return daily_update
def handle_simulation_end(self):
"""
@@ -349,12 +363,8 @@ class PerformanceTracker(object):
exceeded_max_loss = self.exceeded_max_loss
)
if self.results_socket:
log.info("about to stream the risk report...")
risk_dict = self.risk_report.to_dict()
msg = zp.RISK_FRAME(risk_dict)
self.results_socket.send(msg)
risk_dict = self.risk_report.to_dict()
return risk_dict
class Position(object):
@@ -362,8 +372,8 @@ class Position(object):
self.sid = sid
self.amount = 0
self.cost_basis = 0.0 ##per share
self.last_sale_price = None
self.last_sale_date = None
self.last_sale_price = 0.0
self.last_sale_date = 0.0
def update(self, txn):
if(self.sid != txn.sid):
@@ -584,7 +594,6 @@ class PerformancePeriod(object):
return positions
#
def get_positions_list(self):
positions = []
for sid, pos in self.positions.iteritems():
-49
View File
@@ -1,49 +0,0 @@
from collections import defaultdict
from zipline.transforms.base import BaseTransform
class ReturnsTransform(BaseTransform):
def init(self, name):
self.state = {}
self.state['name'] = name
self.by_sid = defaultdict(self._create)
@property
def get_id(self):
return self.state['name']
def transform(self, event):
cur = self.by_sid[event.sid]
cur.update(event)
self.state['value'] = cur.returns
return self.state
def _create(self):
return ReturnsFromPriorClose()
class ReturnsFromPriorClose(object):
"""
Calculates a security's returns since the previous close, using the
current price.
"""
def __init__(self):
self.last_close = None
self.last_event = None
self.returns = 0.0
def update(self, event):
if self.last_close:
change = event.price - self.last_close.price
self.returns = change / self.last_close.price
if self.last_event:
if self.last_event.dt.day != event.dt.day:
# the current event is from the day after
# the last event. Therefore the last event was
# the last close
self.last_close = self.last_event
# the current event is now the last_event
self.last_event = event
+15 -23
View File
@@ -9,10 +9,10 @@ from zipline.protocol import SIMULATION_STYLE
log = logbook.Logger('Transaction Simulator')
class TransactionSimulator(object):
UPDATER = True
def __init__(self, style=SIMULATION_STYLE.PARTIAL_VOLUME):
def __init__(self, sid_filter, style=SIMULATION_STYLE.PARTIAL_VOLUME):
self.open_orders = {}
self.order_count = 0
self.txn_count = 0
self.trade_window = datetime.timedelta(seconds=30)
self.orderTTL = datetime.timedelta(days=1)
@@ -27,27 +27,20 @@ class TransactionSimulator(object):
elif style == SIMULATION_STYLE.NOOP:
self.apply_trade_to_open_orders = self.simulate_noop
def add_open_order(self, event):
# Orders are captured in a buffer by sid. No calculations are done here.
# Amount is explicitly converted to an int.
# Orders of amount zero are ignored.
for sid in sid_filter:
self.open_orders[sid] = []
self.order_count += 1
event.amount = int(event.amount)
def place_order(self, order):
    """
    Queue order on the open-orders list for its sid.

    The order's `filled` field is initialized to 0 first.
    """
    order.filled = 0
    sid_orders = self.open_orders[order.sid]
    sid_orders.append(order)
if event.amount == 0:
log = "requested to trade zero shares of {sid}".format(
sid=event.sid
)
log.debug(log)
return
if not self.open_orders.has_key(event.sid):
self.open_orders[event.sid] = []
# set the filled property to zero
event.filled = 0
self.open_orders[event.sid].append(event)
def update(self, event):
    """
    Set event.TRANSACTION and return the event.

    TRANSACTION is None unless the event's type is
    zp.DATASOURCE_TYPE.TRADE, in which case it is the result of
    apply_trade_to_open_orders(event).
    """
    event.TRANSACTION = None

    # Transactions are only filled on trade events; anything else
    # passes straight through.
    if not (event.type == zp.DATASOURCE_TYPE.TRADE):
        return event

    event.TRANSACTION = self.apply_trade_to_open_orders(event)
    return event
def simulate_buy_all(self, event):
txn = self.create_transaction(
@@ -81,7 +74,7 @@ class TransactionSimulator(object):
txn = self.create_transaction(
event.sid,
amount,
event.price + 0.10,
event.price + 0.10, # Magic constant?
event.dt,
direction
)
@@ -162,7 +155,6 @@ class TransactionSimulator(object):
}
return zp.ndict(txn)
class TradingEnvironment(object):
def __init__(
+78 -68
View File
@@ -1,91 +1,101 @@
import datetime
from itertools import tee, starmap
from itertools import tee, starmap, chain
from collections import namedtuple
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.utils import roundrobin, hash_args
from zipline.gens.utils import roundrobin, hash_args, done_message
from zipline.gens.sort import date_sort
from zipline.gens.merge import merge
from zipline.gens.transform import stateful_transform
from zipline.gens.transform import StatefulTransform
SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs'])
MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs'])
def date_sorted_sources(sources, source_args, source_kwargs):
def date_sorted_sources(*sources):
"""
Takes a list of generator functions, a list of tuples of positional arguments,
and a list of dictionaries of keyword arguments. Packages up all arguments
and passes them into a date_sort.
Takes an iterable of sources, generating namestrings and
piping their output into date_sort.
"""
assert len(sources) == len(source_args) == len(source_kwargs)
# Package up sources and arguments.
# Create a generator of SortBundle objects to be turned into
# namestrings and generator objects.
bundle_gen = starmap(SortBundle, zip(sources, source_args, source_kwargs))
# Load the results of the generator into a tuple so that the
# results can be used twice (once in namestring comprehension,
# once in the generator comprehension for intialized sources.
bundles = tuple(bundle_gen)
for source in sources:
assert iter(source), "Source %s not iterable" % source
assert source.__class__.__dict__.has_key('get_hash'), "No get_hash"
# Calculate namestring hashes to pass to date_sort.
names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs)
for bundle in bundles]
# Pass each source its arguments.
initialized = [bundle.source(*bundle.args, **bundle.kwargs)
for bundle in bundles]
# Get name hashes to pass to date_sort.
names = [source.get_hash() for source in sources]
# Convert the list of generators into a flat stream by pulling
# one element at a time from each.
stream_in = roundrobin(*initialized)
# Guarantee the flat stream will be sorted by date, using source_id as
# tie-breaker, which is fully deterministic (given deterministic string
# representation for all args/kwargs)
stream_in = roundrobin(sources, names)
# Guarantee the flat stream will be sorted by date, using
# source_id as tie-breaker, which is fully deterministic (given
# deterministic string representation for all args/kwargs)
return date_sort(stream_in, names)
def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
def merged_transforms(sorted_stream, *transforms):
"""
A generator that takes the expected output of a date_sort, pipes it
through a given set of transforms, and runs the results throught a
merge to output a unified stream. tnfms should be a list of
pointers to generator functions. tnfm_args should be a list of
tuples, representing the arguments to be passed to each transform.
tnfm_kwargs should be a list of dictionaries representing keyword
arguments to each transform.
A generator that takes the expected output of a date_sort, pipes
it through a given set of transforms, and runs the results
through a merge to output a unified stream. transforms should be
StatefulTransform instances. Each one is given its own copy of
the sorted stream, and its get_hash() value is used as its
namestring in the merge.
"""
# We should have as many sets of args as we have transforms.
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
# Create a copy of the stream for each transform.
split = tee(sorted_stream, len(tnfms))
# Package each transform with a stream copy and set of args. Use a list
# so that we can re-use this for calculating hashes.
bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs))
bundles = tuple(bundle_gen)
# list comprehension to create transform generators from
# bundles
tnfm_gens = [
stateful_transform(
bundle.stream,
bundle.tnfm,
*bundle.args,
**bundle.kwargs
)
for bundle in bundles]
for transform in transforms:
assert isinstance(transform, StatefulTransform)
# Generate expected hashes for each transform
hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
for bundle in bundles]
namestrings = [tnfm.get_hash() for tnfm in transforms]
# Roundrobin the outputs of our transforms to create a single flat stream.
to_merge = roundrobin(*tnfm_gens)
# Create a copy of the stream for each transform.
split = tee(sorted_stream, len(transforms))
# Package a stream copy with each StatefulTransform instance.
bundles = zip(transforms, split)
# Convert the copies into transform streams.
tnfm_gens = [tnfm.transform(stream) for tnfm, stream in bundles]
# Roundrobin the outputs of our transforms to create a single flat
# stream.
to_merge = roundrobin(tnfm_gens, namestrings)
# Pipe the stream into merge.
merged = merge(to_merge, hashes)
return merged_transforms
merged = merge(to_merge, namestrings)
dt_aliased = alias_dt(merged)
# Return the merged events.
return add_done(dt_aliased)
def sequential_transforms(stream_in, *transforms):
    """
    Chain every transform in ``transforms`` onto ``stream_in``, in order.

    Each transform is first switched into append mode (``forward_all`` and
    ``update_in_place`` off, ``append_value`` on).  The transforms are then
    applied one after another: the first consumes ``stream_in`` and each
    later one consumes the output of its predecessor.  The final stream is
    passed through ``alias_dt`` and ``add_done`` before being returned.
    """
    assert isinstance(transforms, (list, tuple))

    # Force append mode on every stage before wiring anything together.
    for stage in transforms:
        stage.forward_all = False
        stage.update_in_place = False
        stage.append_value = True

    # Feed each stage the output of the one before it.
    piped = stream_in
    for stage in transforms:
        piped = stage.transform(piped)

    return add_done(alias_dt(piped))
def alias_dt(stream_in):
    """
    Yield every message of stream_in after copying its 'dt' entry into a
    'datetime' entry.  Messages are modified in place, not copied.
    """
    for event in stream_in:
        event['datetime'] = event['dt']
        yield event
def add_done(stream_in):
    """
    Return stream_in followed by one trailing done message.

    The done message is built immediately with done_message('Composite')
    and chained onto the end of the stream.
    """
    terminator = done_message('Composite')
    return chain(stream_in, [terminator])
+100
View File
@@ -0,0 +1,100 @@
import pytz
import time
from time import sleep
from pprint import pprint as pp
from datetime import datetime, timedelta
from itertools import izip
from zipline.utils.factory import create_trading_environment
from zipline.test_algorithms import TestAlgorithm
from zipline.gens.composites import SourceBundle, TransformBundle, \
date_sorted_sources, merged_transforms, sequential_transforms
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
import zipline.protocol as zp
if __name__ == "__main__":
    # Timing comparison of the two transform composites: two identical
    # pairs of synthetic trade sources are built, one pair is run through
    # merged_transforms and the other through sequential_transforms, and
    # the wall-clock time needed to exhaust each stream is printed.

    # Sid filter shared by every source.  (Named sid_filter so that the
    # builtin `filter` is not shadowed.)
    sid_filter = [2, 3]

    # Set up source a. Six minutes between events.
    args_a = tuple()
    kwargs_a = {
        'count': 1000,
        'sids': [1, 2, 3],
        'start': datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
        'delta': timedelta(minutes=6),
        'filter': sid_filter,
    }
    source_a = SpecificEquityTrades(*args_a, **kwargs_a)
    source_a_prime = SpecificEquityTrades(*args_a, **kwargs_a)

    # Set up source b. Five minutes between events.
    args_b = tuple()
    kwargs_b = {
        'count': 1000,
        'sids': [2, 3, 4],
        'start': datetime(2012, 1, 3, 14, tzinfo=pytz.utc),
        'delta': timedelta(minutes=5),
        'filter': sid_filter,
    }
    source_b = SpecificEquityTrades(*args_b, **kwargs_b)
    source_b_prime = SpecificEquityTrades(*args_b, **kwargs_b)

    # One date-sorted stream per composite.  (Named sorted_stream so that
    # the builtin `sorted` is not shadowed.)
    sorted_stream = date_sorted_sources(source_a, source_b)
    sorted_stream_prime = date_sorted_sources(
        source_a_prime,
        source_b_prime
    )

    # NOTE(review): the positional arguments below are (timedelta, fields).
    # The MovingAverage class added in mavg.py by this same change takes
    # (fields, market_aware, days, delta) -- confirm which signature
    # zipline.gens.transform.MovingAverage actually has.
    passthrough = StatefulTransform(Passthrough)
    mavg_price = StatefulTransform(
        MovingAverage,
        timedelta(minutes=20),
        ['price']
    )

    passthrough_prime = StatefulTransform(Passthrough)
    mavg_price_prime = StatefulTransform(
        MovingAverage,
        timedelta(minutes=20),
        ['price']
    )

    # Time the merged composite.  The loop only drains the stream.
    merged = merged_transforms(sorted_stream, passthrough, mavg_price)
    start = time.time()
    for message in merged:
        pass
    stop = time.time()
    merge_time = stop - start
    print("Merge time: %s" % str(merge_time))

    # Time the sequential composite on the second, identical set of inputs.
    sequential = sequential_transforms(
        sorted_stream_prime,
        passthrough_prime,
        mavg_price_prime
    )
    start = time.time()
    for message in sequential:
        pass
    stop = time.time()
    seq_time = stop - start
    print("Sequential time: %s" % str(seq_time))

    # Guard the ratio: a coarse clock can report an elapsed time of zero.
    if seq_time:
        print("Merge/Seq: %s" % (str(merge_time / seq_time)))
    else:
        print("Merge/Seq: undefined (sequential time measured as zero)")

    # Disabled example: run the merged stream through the trade simulator.
    # merged = merged_transforms(sorted_stream, passthrough, mavg_price)
    # algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
    # environment = create_trading_environment(year = 2012)
    # style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
    # trading_client = tsc(algo, environment, style)
    # for message in trading_client.simulate(merged):
    #     pp(message)
+127
View File
@@ -0,0 +1,127 @@
from numbers import Number
from datetime import datetime, timedelta
from collections import defaultdict
from zipline import ndict
from zipline.gens.transform import EventWindow
class MovingAverage(object):
    """
    Transform state that keeps one MovingAverageEventWindow per sid.

    Every window tracks the same list of fields, so several averages
    (for example price and volume) can be maintained for each sid.
    """

    def __init__(self, fields, market_aware, days = None, delta = None):
        """
        fields       -- names of the event fields to average.
        market_aware -- when true the window must be given in `days`;
                        otherwise it must be given as a timedelta `delta`.
        days, delta  -- window size; exactly one must be supplied, and
                        which one is dictated by `market_aware`.
        """
        self.fields = fields
        self.market_aware = market_aware
        self.delta = delta
        self.days = days

        # Exactly one window specification is legal, and which one depends
        # on the mode.
        if market_aware:
            window_ok = days and not delta
            complaint = "Market-aware mode only works with full-day windows."
        else:
            window_ok = delta and not days
            complaint = "Non-market-aware mode requires a timedelta."
        assert window_ok, complaint

        # defaultdict factories take no arguments, so a bound method is
        # used to build correctly configured windows on demand.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """
        Build a fresh window configured like this transform (used as the
        defaultdict factory for self.sid_windows).
        """
        return MovingAverageEventWindow(
            fields=self.fields,
            market_aware=self.market_aware,
            days=self.days,
            delta=self.delta
        )

    def update(self, event):
        """
        Route the event to the window for event.sid (created on first
        use) and return that window's current averages.
        """
        sid_window = self.sid_windows[event.sid]
        sid_window.update(event)
        return sid_window.get_averages()
class MovingAverageEventWindow(EventWindow):
    """
    EventWindow that keeps a running total for each tracked field, so the
    average over the events currently in the window can be read at any
    time.  Any number of fields (e.g. price and volume) may be tracked.
    Instances are normally created by a MovingAverage transform, one per
    sid.
    """

    def __init__(self, fields, market_aware, days, delta):
        # Let the base class set up the generic window machinery first.
        EventWindow.__init__(self, market_aware, days, delta)

        # Running sum per tracked field; missing fields start at 0.0.
        self.fields = fields
        self.totals = defaultdict(float)

    def handle_add(self, event):
        """
        Hook for an event entering the window: validate it, then add its
        tracked values to the running totals.
        """
        self.assert_required_fields(event)
        for name in self.fields:
            self.totals[name] += event[name]

    def handle_remove(self, event):
        """
        Hook for an event leaving the window: subtract its tracked values
        from the running totals.
        """
        for name in self.fields:
            self.totals[name] -= event[name]

    def average(self, field):
        """
        Return the mean of `field` over the ticks currently held, or 0.0
        when there are no ticks.
        """
        assert field in self.fields

        # self.ticks is not defined in this class; presumably it is
        # maintained by EventWindow -- TODO confirm.
        tick_count = len(self.ticks)
        if tick_count == 0:
            return 0.0
        return self.totals[field] / tick_count

    def get_averages(self):
        """
        Return an ndict mapping every tracked field to its current average.
        """
        averages = ndict()
        for name in self.fields:
            averages[name] = self.average(name)
        return averages

    def assert_required_fields(self, event):
        """
        Assert that the event carries a numeric value for every tracked
        field.
        """
        for name in self.fields:
            assert event.has_key(name), \
                "Event missing [%s] in MovingAverageEventWindow" % name
            assert isinstance(event[name], Number), \
                "Got %s for %s in MovingAverageEventWindow" % (event[name], name)
+28 -16
View File
@@ -6,13 +6,13 @@ from collections import deque
from zipline import ndict
from zipline.gens.utils import hash_args, \
assert_merge_protocol
assert_merge_protocol, done_message
from itertools import repeat
def merge(stream_in, tnfm_ids):
"""
A generator that takes a generator and a list of source_ids. We
maintain an internal queue for each id in source_ids. Once we
A generator that takes a generator and a list of transform ids. We
maintain an internal queue for each id in tnfm_ids. Once we
have a message from every queue, we pop an event from each queue
and merge them together into an event. We raise an error if we
do not receive the same number of events from all sources.
@@ -28,22 +28,22 @@ def merge(stream_in, tnfm_ids):
# Process incoming streams.
for message in stream_in:
assert isinstance(message, tuple), \
"Bad message in merge: %s" %message
assert len(message) == 2
id, value = message
assert isinstance(message, ndict)
assert message.has_key('tnfm_id')
assert message.has_key('tnfm_value')
assert message.has_key('dt')
id = message.tnfm_id
assert id in tnfm_ids, \
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
assert isinstance(value, ndict), "Bad message in merge: %s" %message
tnfms[id].append(value)
tnfms[id].append(message)
# Only pop messages when we have a pending message from
# all datasources. Stop if all sources have signalled done.
while ready(tnfms) and not done(tnfms):
message = merge_one(tnfms)
assert_merge_protocol(tnfm_ids, message)
yield message
# We should have only a done message left in each queue.
@@ -51,13 +51,25 @@ def merge(stream_in, tnfm_ids):
assert len(queue) == 1, "Bad queue in merge on exit: %s" % queue
assert queue[0].dt == "DONE", \
"Bad last message in merge on exit: %s" % queue
def merge_one(sources):
output = ndict()
event_fields = ndict()
for key, queue in sources.iteritems():
new_xform = ndict({key: queue.popleft()})
output.merge(new_xform)
return output
# Add transform value to the transforms dict.
message = queue.popleft()
event_fields[message.tnfm_id] = message.tnfm_value
del message['tnfm_id']
del message['tnfm_value']
# Merge any remaining fields into the event dict.
event_fields.merge(message)
# alias dt with datetime, per algoscript api
event_fields['datetime'] = event_fields['dt']
return event_fields
#TODO: This is replicated in sort. Probably should be one source file.
+74
View File
@@ -0,0 +1,74 @@
from collections import defaultdict, deque
class Returns(object):
    """
    Transform state that keeps one ReturnsFromPriorClose tracker per sid
    and reports each sid's return relative to its close `days` trading
    days ago.
    """

    def __init__(self, days):
        self.days = days
        # Trackers are created lazily, the first time a sid is seen.
        self.mapping = defaultdict(self._create)

    def update(self, event):
        """
        Feed the event to the tracker for event.sid and return that
        tracker's current returns value.  The event must carry 'dt' and
        'price'.
        """
        for required in ('dt', 'price'):
            assert event.has_key(required)

        sid_tracker = self.mapping[event.sid]
        sid_tracker.update(event)
        return sid_tracker.get_returns()

    def _create(self):
        # defaultdict factory: a new tracker with this transform's window.
        return ReturnsFromPriorClose(self.days)
class ReturnsFromPriorClose(object):
    """
    Tracks the last `days` closing events for one security, plus the most
    recent event seen.  When an event arrives for a new calendar day, the
    previously seen event is treated as the close of the day it belongs to.

    `days` is expected to be at least 1: with 0 the closes queue is always
    empty when it is indexed in update().
    """

    def __init__(self, days):
        # Closing events, oldest first; at most `days` are retained.
        self.closes = deque()
        # Most recent event seen, or None before the first update.
        self.last_event = None
        # Latest computed return; stays 0.0 until `days` closes exist.
        self.returns = 0.0
        self.days = days

    def get_returns(self):
        """Return the most recently computed return (0.0 until available)."""
        return self.returns

    def update(self, event):
        """
        Record `event` (which must expose `dt` and `price`) and recompute
        the return relative to the oldest stored close.
        """
        previous = self.last_event

        # Compare against None explicitly, so that a stored event which
        # happens to evaluate as false is not mistaken for "no event yet".
        if previous is not None:
            # Day has changed since the last event we saw.  Treat the
            # last event as the closing price for its day.
            if previous.dt.date() != event.dt.date():
                self.closes.append(previous)

                # One event is kept per trading day, so anything beyond
                # `days` entries has expired; drop the oldest ones.
                while len(self.closes) > self.days:
                    self.closes.popleft()

        # A return is only produced once enough closes have been seen to
        # give a sensible value.  Earlier closes cannot be looked up here:
        # that would require giving this transform database credentials.
        if len(self.closes) == self.days:
            last_close = self.closes[0].price
            change = event.price - last_close
            # Force true division: under Python 2, integer prices would
            # otherwise be floor-divided and the fraction discarded.
            self.returns = float(change) / last_close

        # The current event is now the last event.
        self.last_event = event
+2 -3
View File
@@ -14,7 +14,6 @@ def date_sort(stream_in, source_ids):
have messages pending from all sources, we pull the earliest
message and yield it.
"""
assert isinstance(source_ids, (list, tuple))
# Set up an internal queue for each expected source.
@@ -28,7 +27,7 @@ def date_sort(stream_in, source_ids):
# Incoming messages should be the output of DATASOURCE_UNFRAME.
assert_datasource_unframe_protocol(message), \
"Bad message in date_sort: %s" % message
# Only allow messages from sources we expect.
assert message.source_id in sources, "Unexpected source: %s" % message
@@ -41,7 +40,7 @@ def date_sort(stream_in, source_ids):
message = pop_oldest(sources)
assert_sort_protocol(message)
yield message
# We should have only a done message left in each queue.
for queue in sources.itervalues():
assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue
+93 -60
View File
@@ -3,13 +3,14 @@ Tools to generate trade events without a backing store. Useful for testing
and zipline development
"""
import random
import pytz
from itertools import chain, cycle, ifilter, izip
from datetime import datetime, timedelta
from zipline.utils.factory import create_trade
from zipline.gens.utils import hash_args, mock_done
from zipline.gens.utils import hash_args, create_trade
def date_gen(start = datetime(2012, 6, 6, 0),
def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
delta = timedelta(minutes = 1),
count = 100):
"""
@@ -25,9 +26,9 @@ def mock_prices(count, rand = False):
"""
if rand:
return (random.uniform(0.0, 10.0) for i in xrange(count))
return (random.uniform(1.0, 10.0) for i in xrange(count))
else:
return (float(i % 11) for i in xrange(1,count+1))
return (float(i % 10) + 1.0 for i in xrange(count))
def mock_volumes(count, rand = False):
"""
@@ -49,72 +50,104 @@ def fuzzy_dates(count = 500):
for date in date_gen(count = count):
yield date + timedelta(seconds = random.randint(-10, 10))
def SpecificEquityTrades(*args, **config):
class SpecificEquityTrades(object):
"""
Yields all events in event_list that match the given sid_filter.
If no event_list is specified, generates an internal stream of events
to filter. Returns all events if filter is None.
Configuration options:
count : integer representing number of trades
sids : list of values representing simulated internal sids
start : start date
delta : timedelta between internal events
filter : filter to remove the sids
"""
# We shouldn't get any positional arguments.
assert args == ()
# Unpack config dictionary with default values.
count = config.get('count', 500)
sids = config.get('sids', [1, 2])
start = config.get('start', datetime(2012, 6, 6, 0))
delta = config.get('delta', timedelta(minutes = 1))
def __init__(self, *args, **kwargs):
# We shouldn't get any positional arguments.
assert len(args) == 0
# Default to None for event_list and filter.
event_list = config.get('event_list')
filter = config.get('filter')
# Unpack config dictionary with default values.
self.count = kwargs.get('count', 500)
self.sids = kwargs.get('sids', [1, 2])
self.start = kwargs.get('start', datetime(2008, 6, 6, 15, tzinfo = pytz.utc))
self.delta = kwargs.get('delta', timedelta(minutes = 1))
arg_string = hash_args(*args, **config)
namestring = "SpecificEquityTrades" + arg_string
# If we have an event_list, ignore the other arguments and use the list.
# TODO: still append our namestring?
if event_list:
unfiltered = (event for event in event_list)
# Default to None for event_list and filter.
self.event_list = kwargs.get('event_list')
self.filter = kwargs.get('filter')
# Set up iterators for each expected field.
else:
dates = date_gen(count = count, start = start, delta = delta)
prices = mock_prices(count)
volumes = mock_volumes(count)
# Hash_value for downstream sorting.
self.arg_string = hash_args(*args, **kwargs)
self.generator = self.create_fresh_generator()
def __iter__(self):
return self
def next(self):
return self.generator.next()
def rewind(self):
self.generator = self.create_fresh_generator()
def get_hash(self):
return self.__class__.__name__ + "-" + self.arg_string
def create_fresh_generator(self):
if self.event_list:
unfiltered = (event for event in self.event_list)
# Set up iterators for each expected field.
else:
dates = date_gen(count=self.count,
start=self.start,
delta=self.delta
)
prices = mock_prices(self.count)
volumes = mock_volumes(self.count)
sids = cycle(self.sids)
# Combine the iterators into a single iterator of arguments
arg_gen = izip(sids, prices, volumes, dates)
# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id = self.get_hash())
for args in arg_gen)
# If we specified a sid filter, filter out elements that don't
# match the filter.
if self.filter:
filtered = ifilter(lambda event: event.sid in self.filter, unfiltered)
# Otherwise just use all events.
else:
filtered = unfiltered
# Return the filtered event stream.
return filtered
# !!!!!!! Deprecated for now !!!!!!!!!
def RandomEquityTrades(object):
def __init__(self):
# We shouldn't get any positional args.
assert args == ()
self.count = config.get('count', 500)
self.sids = config.get('sids', [1,2])
self.filter = config.get('filter')
dates = fuzzy_dates(count)
prices = mock_prices(count, rand = True)
volumes = mock_volumes(count, rand = True)
sids = cycle(sids)
# Combine the iterators into a single iterator of arguments
arg_gen = izip(sids, prices, volumes, dates)
# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id = namestring)
for args in arg_gen)
# If we specified a sid filter, filter out elements that don't match the filter.
if filter:
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
# Otherwise just use all events.
else:
filtered = unfiltered
# Add a done message to the end of the stream. For a live
# datasource this would be handled by the containing Component.
out = chain(filtered, [mock_done(namestring)])
return out
def RandomEquityTrades(*args, **config):
# We shouldn't get any positional args.
assert args == ()
count = config.get('count', 500)
sids = config.get('sids', [1,2])
filter = config.get('filter')
dates = fuzzy_dates(count)
prices = mock_prices(count, rand = True)
volumes = mock_volumes(count, rand = True)
sids = cycle(sids)
arg_gen = izip(sids, prices, volumes, dates)
unfiltered = (create_trade(*args) for args in arg_gen)
+307
View File
@@ -0,0 +1,307 @@
from logbook import Logger, Processor
from datetime import datetime, timedelta
from numbers import Integral
from zipline import ndict
from zipline.gens.transform import StatefulTransform
from zipline.finance.trading import TransactionSimulator
from zipline.finance.performance import PerformanceTracker
from zipline.utils.log_utils import stdout_only_pipe
from zipline.gens.utils import hash_args
log = Logger('Trade Simulation')
class TradeSimulationClient(object):
    """
    Runs a user algorithm against a merged event stream.

    Constructed from a user algorithm, a trading environment and a
    simulation style.  simulate() pipes the incoming stream through a
    TransactionSimulator and then a PerformanceTracker (each wrapped in a
    StatefulTransform), and hands the result to an AlgorithmSimulator
    together with the TransactionSimulator state, so that the algorithm
    can place orders directly into the simulator's order book.

    Notes carried over from the original description of the
    collaborators, which are defined outside this module:

    TransactionSimulator maintains a dictionary from sids to the
    unfulfilled orders placed by the user's algorithm.  As trade events
    arrive, if the algorithm has open orders against the trade's sid, the
    simulator will fill orders up to 25% of market cap.  Applied
    transactions are added to a txn field on the event and forwarded to
    PerformanceTracker.  The txn field is set to None on non-trade events
    and events that do not match any open orders.

    PerformanceTracker receives the updated events, maintaining daily and
    cumulative performance metrics for the algorithm.  It removes the txn
    field from each event and replaces it with a portfolio field to be fed
    to the user algo.  At the end of each trading day it also generates a
    daily performance report, which is attached to the event.

    Fully processed events with the same dt are batched into a single
    snapshot for the algo.  The portfolio object is repeatedly overwritten
    so that only the most recent state reaches the algo.
    """

    def __init__(self, algo, environment, sim_style):
        self.algo = algo
        self.sids = algo.get_sid_filter()
        self.environment = environment
        self.style = sim_style
        # Populated by simulate() once the pipeline has been built.
        self.algo_sim = None

    def get_hash(self):
        """
        Identifier for this client: the class name followed by the hash of
        an empty argument list, since only one client is expected to exist
        in the system.
        """
        return self.__class__.__name__ + hash_args()

    def simulate(self, stream_in):
        """
        Generator: build the processing pipeline over stream_in and yield
        every message the algorithm simulator produces.
        """
        # Stage 1: fill any open orders left by the previous run of the
        # user's algorithm.
        txn_stage = StatefulTransform(
            TransactionSimulator,
            self.sids,
            style = self.style
        )
        filled_stream = txn_stage.transform(stream_in)

        # Stage 2: performance tracking on the events carrying
        # transactions.
        perf_stage = StatefulTransform(
            PerformanceTracker,
            self.environment,
            self.sids
        )
        portfolio_stream = perf_stage.transform(filled_stream)

        # Stage 3: the algorithm itself.  It receives the transaction
        # simulator's state so that it can place new orders into the
        # order book.
        self.algo_sim = AlgorithmSimulator(
            portfolio_stream,
            txn_stage.state,
            self.algo,
        )

        # Relay everything the simulator emits.
        for emitted in self.algo_sim:
            yield emitted
class AlgorithmSimulator(object):
    """
    Iterator that drives a user algorithm from a stream of events.

    Events sharing the same dt are folded into one snapshot of the
    universe.  When an event with a different dt arrives, the algorithm's
    handle_data is run on the snapshot and the time it took is added to
    the simulated clock.  Events that arrive while the simulated clock is
    still ahead of them only update the universe; they do not start a new
    snapshot.  Perf messages found on events are yielded to the caller.
    """

    def __init__(self, stream_in, order_book, algo):
        """
        stream_in  -- iterable of events; each is read via attribute access
                      (dt, sid, portfolio, perf_message), keys() and item
                      access.
        order_book -- object exposing place_order(order).
        algo       -- user algorithm exposing get_sid_filter, set_order,
                      set_logger, initialize and handle_data.
        """
        self.stream_in = stream_in

        # ==========
        # Algo Setup
        # ==========

        # The order book is kept so that the algo can place new orders.
        self.order_book = order_book

        self.algo = algo
        self.sids = algo.get_sid_filter()

        # Hand the algorithm our order method so that its orders end up
        # in the order book, and give it a logger.
        self.algo.set_order(self.order)
        self.algo.set_logger(Logger("AlgoLog"))

        # ==============
        # Snapshot Setup
        # ==============

        # The algorithm's universe as of our most recent event: one ndict
        # per sid plus a portfolio entry.
        self.universe = ndict()
        for sid in self.sids:
            self.universe[sid] = ndict()
        self.universe.portfolio = None

        # We don't have a datetime for the current snapshot until we
        # receive a message.
        self.simulation_dt = None
        self.this_snapshot_dt = None

        # =============
        # Logging Setup
        # =============

        # Processor function for injecting the algo_dt into
        # user prints/logs.
        def inject_algo_dt(record):
            record.extra['algo_dt'] = self.this_snapshot_dt
        self.processor = Processor(inject_algo_dt)

        # Stored uncalled; it is invoked in _gen and used there as a
        # context manager.
        self.stdout_capture = stdout_only_pipe

        # Work-loop generator, created lazily on the first call to next().
        self.__generator = None

    def __iter__(self):
        return self

    def next(self):
        """Return the next message from the work loop, starting it if needed."""
        if self.__generator is None:
            self.__generator = self._gen()
        return next(self.__generator)

    def order(self, sid, amount):
        """
        Order callback handed to the user's algo; places an order for
        `amount` shares of `sid` into the order book.  Orders for zero
        shares are logged and dropped.
        """
        # %s rather than %i: a non-integer sid must still produce the
        # intended AssertionError instead of a formatting TypeError.
        assert sid in self.sids, "Order on invalid sid: %s" % (sid,)
        order = ndict({
            'dt' : self.simulation_dt,
            'sid' : sid,
            'amount' : int(amount),
            'filled' : 0
        })

        # Tell the user if they try to buy 0 shares of something.
        if order.amount == 0:
            zero_message = "Requested to trade zero shares of {sid}".format(
                sid=order.sid
            )
            log.debug(zero_message)
            # Don't bother placing orders for 0 shares.
            return

        # Add non-zero orders to the order book.
        # !!!IMPORTANT SIDE-EFFECT!!!
        # This modifies the internal state of the transaction
        # simulator so that it can fill the placed order when it
        # receives its next message.
        self.order_book.place_order(order)

    def _gen(self):
        """
        Internal generator work loop.
        """
        # Capture any output of this generator to stdout and pipe it
        # to a logbook interface.  Also inject the current algo
        # snapshot time to any log record generated.
        with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):

            # Call the user's initialize method.
            self.algo.initialize()

            for event in self.stream_in:
                # Yield any perf messages received to be relayed back to
                # the browser.
                # NOTE(review): the flattened source does not show whether
                # the `del` sits inside this `if`; it is kept inside here.
                # Confirm against the original file.
                if event.perf_message:
                    yield event.perf_message
                    del event['perf_message']

                if event.dt == "DONE":
                    if self.this_snapshot_dt is not None:
                        # The stream ended mid-snapshot, so we have a
                        # universe snapshot that is not yet processed
                        # by the algorithm.
                        self.simulate_current_snapshot()
                    break

                # This should only happen for the first event we run.
                if self.simulation_dt is None:
                    self.simulation_dt = event.dt

                # ======================
                # Time Compression Logic
                # ======================

                if self.this_snapshot_dt is not None:
                    self.update_current_snapshot(event)

                # The algorithm has been missing events because it took
                # too long processing.  Update the universe with data from
                # this event, then check if enough time has passed that we
                # can start a new snapshot.
                else:
                    self.update_universe(event)
                    if event.dt >= self.simulation_dt:
                        self.this_snapshot_dt = event.dt

    def update_current_snapshot(self, event):
        """
        Update our current snapshot of the universe.  Run the algorithm on
        the snapshot first if the event's dt differs from the snapshot's.
        """
        # The new event matches our snapshot dt.  Just update the
        # universe and move on.
        if event.dt == self.this_snapshot_dt:
            self.update_universe(event)

        # The new event does not match our snapshot.
        else:
            self.simulate_current_snapshot()

            # Once we've finished simulating the old snapshot,
            # we can update the universe with the new event.
            self.update_universe(event)

            # The current event is later than the simulation time,
            # which means the algorithm finished quickly enough to
            # receive the new event.  Start a new snapshot with this
            # event's dt.
            if event.dt >= self.simulation_dt:
                self.this_snapshot_dt = event.dt

            # The algorithm spent enough time processing that it
            # missed the new event.  Wait to start a new snapshot until
            # the events catch up to the algo's simulated dt.
            else:
                self.this_snapshot_dt = None

    def simulate_current_snapshot(self):
        """
        Run the user's algo against our current snapshot and update the algo's
        simulated time.
        """
        start_tic = datetime.now()
        self.algo.handle_data(self.universe)
        stop_tic = datetime.now()

        # Wall-clock time the algorithm took to process the snapshot.
        delta = stop_tic - start_tic

        # Advance the simulated clock by that processing time.
        self.simulation_dt = self.this_snapshot_dt + delta

    def update_universe(self, event):
        """
        Update the universe with new event information.
        """
        # Update our portfolio.
        self.universe.portfolio = event.portfolio

        # Copy every field of the event into the entry for its sid.
        for field in event.keys():
            self.universe[event.sid][field] = event[field]
+168 -106
View File
@@ -2,16 +2,21 @@
Generator versions of transforms.
"""
import types
import pytz
from datetime import datetime
from copy import deepcopy
from datetime import datetime, timedelta
from collections import deque, defaultdict
from numbers import Number
from abc import ABCMeta, abstractmethod
from zipline import ndict
from zipline.utils.tradingcalendar import trading_days_between
from zipline.gens.utils import assert_sort_unframe_protocol, \
assert_transform_protocol, hash_args
class Passthrough(object):
FORWARDER = True
"""
Trivial class for forwarding events.
"""
@@ -19,11 +24,9 @@ class Passthrough(object):
pass
def update(self, event):
assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event
assert event.has_key('sid'), "No sid in Passthrough: %s" % event
assert event.has_key('dt'), "No dt in Passthorughz: %s" % event
return event
pass
# Deprecated
def functional_transform(stream_in, func, *args, **kwargs):
"""
Generic transform generator that takes each message from an in-stream
@@ -40,131 +43,195 @@ def functional_transform(stream_in, func, *args, **kwargs):
assert_transform_protocol(out_value)
yield(namestring, out_value)
def stateful_transform(stream_in, tnfm_class, *args, **kwargs):
class StatefulTransform(object):
"""
Generic transform generator that takes each message from an in-stream
and sorts it to a state class. For each call to update, the state
class must produce a message to be fed downstream.
Generic transform generator that takes each message from an
in-stream and passes it to a state object. For each call to
update, the state class must produce a message to be fed
downstream. Any transform class with the FORWARDER class variable
set to true will forward all fields in the original message.
Otherwise only dt, tnfm_id, and tnfm_value are forwarded.
"""
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
def __init__(self, tnfm_class, *args, **kwargs):
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
"Stateful transform requires a class."
assert tnfm_class.__dict__.has_key('update'), \
assert tnfm_class.__dict__.has_key('update'), \
"Stateful transform requires the class to have an update method"
# Create an instance of our transform class.
state = tnfm_class(*args, **kwargs)
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
self.append_value = tnfm_class.__dict__.get('APPENDER', False)
# Generate the string associated with this generator's output.
namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
# You only one special behavior mode can be set.
assert sum(map(int, [self.forward_all,
self.update_in_place,
self.append_value])) <= 1
for message in stream_in:
assert_sort_unframe_protocol(message)
out_value = state.update(message)
assert_transform_protocol(out_value)
yield (namestring, out_value)
# Create an instance of our transform class.
self.state = tnfm_class(*args, **kwargs)
class MovingAverage(object):
"""
Class that maintains a dictionary from sids to EventWindows
Upon receipt of each message we update the
corresponding window and return the calculated average.
# Create the string associated with this generator's output.
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
def get_hash(self):
return self.namestring
def transform(self, stream_in):
return self._gen(stream_in)
def _gen(self, stream_in):
# IMPORTANT: Messages may contain pointers that are shared with
# other streams, so we only manipulate copies.
for message in stream_in:
# allow upstream generators to yield None to avoid
# blocking.
if message == None:
continue
#TODO: refactor this to avoid unnecessary copying.
assert_sort_unframe_protocol(message)
message_copy = deepcopy(message)
# Same shared pointer issue here as above.
tnfm_value = self.state.update(deepcopy(message_copy))
# FORWARDER flag means we want to keep all original
# values, plus append tnfm_id and tnfm_value. Used for
# preserving the original event fields when our output
# will be fed into a merge.
if self.forward_all:
out_message = message_copy
out_message.tnfm_id = self.namestring
out_message.tnfm_value = tnfm_value
yield out_message
# UPDATER flag should be used for transforms that
# side-effectfully modify the event they are passed.
# Updated messages are passed along exactly as they are
# returned to use by our state class. Useful for chaining
# specific transforms that won't be fed to a merge. (See
# the implementation of TradeSimulationClient for example
# usage of this flag with PerformanceTracker and
# TransactionSimulator.
elif self.update_in_place:
yield tnfm_value
# APPENDER flag should be used to add a single new
# key-value pair to the event. The new key is this
# transform's namestring, and it's value is the value
# returned by state.update(event). This is almost
# identical to the behavior of FORWARDER, except we
# compress the two calculated values (tnfm_id, and
# tnfm_value) into a single field. This mode is used by
# the sequential_transforms composite.
elif self.append_value:
out_message = message_copy
out_message[self.namestring] = tnfm_value
yield out_message
# If no flags are set, we create a new message containing
# just the tnfm_id, the event's datetime, and the
# calculated tnfm_value. This is the default behavior for
# a transform being fed into a merge.
else:
out_message = ndict()
out_message.tnfm_id = self.namestring
out_message.tnfm_value = tnfm_value
out_message.dt = message_copy.dt
yield out_message
class EventWindow:
"""
Abstract base class for transform classes that calculate iterative
metrics on events within a given timedelta. Maintains a list of
events that are within a certain timedelta of the most recent
tick. Calls self.handle_add(event) for each event added to the
window. Calls self.handle_remove(event) for each event removed
from the window. Subclass these methods along with init(*args,
**kwargs) to calculate metrics over the window.
def __init__(self, delta, fields):
The market_aware flag is used to toggle whether the eventwindow
calculates
See zipline/gens/mavg.py and zipline/gens/vwap.py for example
implementations of moving average and volume-weighted average
price.
"""
# Mark this as an abstract base class.
__metaclass__ = ABCMeta
def __init__(self, market_aware, days = None, delta = None):
self.market_aware = market_aware
self.days = days
self.delta = delta
self.fields = fields
# No way to pass arguments to the defaultdict factory, so we
# need to define a method to generate the correct EventWindows.
self.sid_windows = defaultdict(self.create_window)
def create_window(self):
"""Factory method for self.sid_windows."""
return EventWindow(self.delta, self.fields)
def update(self, event):
"""
Update the event window for this event's sid. Return an ndict from
tracked fields to averages.
"""
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
assert event.has_key('dt'), "No dt in MovingAverage: %s" % event
output = ndict({'sid': event.sid, 'dt': event.dt})
# This will create a new EventWindow if this is the first
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
averages = window.get_averages()
# Return the calculated averages along with
output.merge(averages)
return output
class EventWindow(object):
"""
Maintains a list of events that are within a certain timedelta
of the most recent tick. The expected use of this class is to
track events associated with a single sid. We provide simple
functionality for averages, but anything more complicated
should be handled by a containing class.
"""
def __init__(self, delta, fields):
self.ticks = deque()
self.delta = delta
self.fields = fields
self.totals = defaultdict(float)
# Market-aware mode only works with full-day windows.
if self.market_aware:
assert self.days and not self.delta,\
"Market-aware mode only works with full-day windows."
# Non-market-aware mode requires a timedelta.
else:
assert self.delta and not self.days, \
"Non-market-aware mode requires a timedelta."
# Set the behavior for dropping events from the back of the
# event window.
if self.market_aware:
self.drop_condition = self.out_of_market_window
else:
self.drop_condition = self.out_of_delta
@abstractmethod
def handle_add(self, event):
raise NotImplementedError()
@abstractmethod
def handle_remove(self, event):
raise NotImplementedError()
def __len__(self):
return len(self.ticks)
def update(self, event):
self.assert_well_formed(event)
# Add new event and increment totals.
self.ticks.append(event)
for field in self.fields:
self.totals[field] += event[field]
# We return a list of all out-of-range events we removed.
out_of_range = []
# Subclasses should override handle_add to define behavior for
# adding new ticks.
self.handle_add(event)
# Clear out expired events, decrementing totals.
# newest oldest
# | |
# V V
# Clear out any expired events. drop_condition changes depending
# on whether or not we are running in market_aware mode.
#
# oldest newest
# | |
# V V
while self.drop_condition(self.ticks[0].dt, self.ticks[-1].dt):
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
# popleft removes and returns ticks[0]
# popleft removes and returns the oldest tick in self.ticks
popped = self.ticks.popleft()
# Decrement totals
for field in self.fields:
self.totals[field] -= popped[field]
# Add the popped element to the list of dropped events.
out_of_range.append(popped)
return out_of_range
# Subclasses should override handle_remove to define
# behavior for removing ticks.
self.handle_remove(popped)
def average(self, field):
assert field in self.fields
if len(self.ticks) == 0:
return 0.0
else:
return self.totals[field] / len(self.ticks)
def out_of_market_window(self, oldest, newest):
return trading_days_between(oldest, newest) >= self.days
def get_averages(self):
"""
Return an ndict of all our tracked averages.
"""
out = ndict()
# out.ticks = len(self.ticks)
for field in self.fields:
out[field] = self.average(field)
return out
def out_of_delta(self, oldest, newest):
return (newest - oldest) >= self.delta
# All event windows expect to receive events with datetime fields
# that arrive in sorted order.
def assert_well_formed(self, event):
assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event
assert event.has_key('dt'), "Missing dt in EventWindow:%s" % event
@@ -173,8 +240,3 @@ class EventWindow(object):
# Something is wrong if new event is older than previous.
assert event.dt >= self.ticks[-1].dt, \
"Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0])
for field in self.fields:
assert event.has_key(field), \
"Event missing [%s] in EventWindow" % field
assert isinstance(event[field], Number), \
"Got %s for %s in EventWindow" % (event[field], field)
+44 -7
View File
@@ -1,6 +1,7 @@
import pytz
import numbers
from collections import OrderedDict
from hashlib import md5
from datetime import datetime
from itertools import izip_longest
@@ -16,8 +17,16 @@ def mock_raw_event(sid, dt):
}
return event
def mock_done(source_id):
return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0})
def mock_done(id):
return ndict({
'dt' : "DONE",
"source_id" : id,
'tnfm_id' : id,
'tnfm_value': None,
'type' : DATASOURCE_TYPE.DONE
})
done_message = mock_done
def alternate(g1, g2):
"""Specialized version of roundrobin for just 2 generators."""
@@ -27,16 +36,25 @@ def alternate(g1, g2):
if e2 != None:
yield e2
def roundrobin(*args):
def roundrobin(sources, namestrings):
"""
Takes N generators, pulling one element off each until all inputs
are empty.
"""
for elem_tuple in izip_longest(*args):
for value in elem_tuple:
if value != None:
yield value
assert len(sources) == len(namestrings)
mapping = OrderedDict(zip(namestrings, sources))
# While our generators have not been exhausted, pull elements
while mapping.keys() != []:
for namestring, source in mapping.iteritems():
try:
message = source.next()
# allow sources to yield None to avoid blocking.
if message:
yield message
except StopIteration:
yield done_message(namestring)
del mapping[namestring]
def hash_args(*args, **kwargs):
"""Define a unique string for any set of representable args."""
@@ -48,6 +66,25 @@ def hash_args(*args, **kwargs):
hasher.update(combined)
return hasher.hexdigest()
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
    """
    Build a single TRADE event as an ndict.

    The ``amount`` argument is stored under the 'volume' key and
    ``datetime`` under the 'dt' key; every other argument is stored
    under its own name. The event type is DATASOURCE_TYPE.TRADE.
    """
    return ndict({
        'source_id' : source_id,
        'type'      : DATASOURCE_TYPE.TRADE,
        'sid'       : sid,
        'dt'        : datetime,
        'price'     : price,
        'volume'    : amount
    })
def sum_true(bool_iterable):
    """
    Count the True entries of an iterable of booleans.

    Each element is converted with int() and the results are summed, so
    an empty iterable gives 0.
    """
    return sum(int(flag) for flag in bool_iterable)
def assert_datasource_protocol(event):
"""Assert that an event meets the protocol for datasource outputs."""
+86
View File
@@ -0,0 +1,86 @@
from numbers import Number
from datetime import datetime, timedelta
from collections import defaultdict
from zipline import ndict
from zipline.gens.transform import EventWindow
class VWAP(object):
    """
    Per-sid vwap tracker.

    Keeps one VWAPEventWindow per sid, created lazily the first time an
    event for that sid is seen, and forwards every event to the window
    belonging to its sid.
    """

    def __init__(self, market_aware, delta=None, days=None):
        self.market_aware = market_aware
        self.delta = delta
        self.days = days

        # Exactly one window size may be given, and which one depends on
        # the mode: whole days when market-aware, a timedelta otherwise.
        if market_aware:
            assert days and not delta, \
                "Market-aware mode only works with full-day windows."
        else:
            assert delta and not days, \
                "Non-market-aware mode requires a timedelta."

        # defaultdict invokes its factory without arguments, so we hand
        # it a bound method that knows how to configure a new window.
        self.sid_windows = defaultdict(self.create_window)

    def create_window(self):
        """Build a window configured like this tracker (defaultdict factory)."""
        return VWAPEventWindow(
            self.market_aware,
            days=self.days,
            delta=self.delta
        )

    def update(self, event):
        """
        Feed ``event`` to the window for ``event.sid`` and return that
        window's current vwap.
        """
        # Indexing the defaultdict creates the window on first use.
        sid_window = self.sid_windows[event.sid]
        sid_window.update(event)
        return sid_window.get_vwap()
class VWAPEventWindow(EventWindow):
    """
    Iteratively maintains a volume-weighted average price (vwap) for a
    single sid over the events currently held by the window.

    Two running totals are kept in step with the window contents:
    ``flux`` (the sum of volume * price) and ``totalvolume`` (the sum of
    volume). The vwap is their ratio.
    """

    def __init__(self, market_aware, days=None, delta=None):
        EventWindow.__init__(self, market_aware, days, delta)
        self.flux = 0.0
        self.totalvolume = 0.0

    def handle_add(self, event):
        """Fold a newly added event into the running totals."""
        # Sanity check on the event before touching the totals.
        self.assert_required_fields(event)
        self.flux += event.volume * event.price
        self.totalvolume += event.volume

    def handle_remove(self, event):
        """Back an expired event out of the running totals."""
        self.flux -= event.volume * event.price
        self.totalvolume -= event.volume

    def get_vwap(self):
        """
        Return the calculated vwap for this sid.

        By convention the vwap is None when the window holds no events.
        The same value is returned when the events in the window carry no
        volume at all, because the ratio is undefined there and dividing
        would raise ZeroDivisionError.
        """
        if len(self.ticks) == 0 or self.totalvolume == 0:
            return None
        return self.flux / self.totalvolume

    def assert_required_fields(self, event):
        """Assert that price and volume are numeric, as the vwap needs."""
        assert isinstance(event.price, Number)
        assert isinstance(event.volume, Number)
+5 -3
View File
@@ -2,15 +2,17 @@ import zmq
import zipline.protocol as zp
def gen_from_zmq(poller, unframe):
def gen_from_zmq(poller, unframe, namestring):
"""
A generator that takes an initialized zmq poller and yields
messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE.
"""
while True:
message = poller.recv()
if message = zp.CONTROL_PROTOCOL.DONE:
yield "DONE"
# Done protocol should now be a message type so that
# done messages can also have source_ids.
if message.type == zp.CONTROL_PROTOCOL.DONE:
yield done_message(message.source_id)
break
else:
yield unframe(message)
+19
View File
@@ -0,0 +1,19 @@
import zmq
import zipline.protocol as zp
def gen_from_pull_socket(socket_uri, context, unframe):
    """
    Connect a zmq PULL socket to ``socket_uri``, register it for POLLIN
    on a fresh poller, and return whatever gen_from_poller builds from
    that poller/socket pair and ``unframe``.

    The messages are intended to be yielded until a
    zp.CONTROL_PROTOCOL.DONE arrives; that loop lives in gen_from_poller,
    which is defined elsewhere.
    """
    sock = context.socket(zmq.PULL)
    sock.connect(socket_uri)

    incoming = zmq.Poller()
    incoming.register(sock, zmq.POLLIN)

    return gen_from_poller(incoming, sock, unframe)
# this generator needs to know about the source_ids coming in via
# the poller, and need to yield DONE messages for each
# source_id.
+188 -233
View File
@@ -59,106 +59,190 @@ before invoking simulate.
| __init__. |
+---------------------------------+
"""
import inspect
import logbook
import zipline.utils.factory as factory
from zipline.components import DataSource
from zipline.transforms import BaseTransform
import sys
import zmq
import os
from signal import SIGHUP, SIGINT
import multiprocessing
from setproctitle import setproctitle
from zipline.test_algorithms import TestAlgorithm
from zipline.components import TradeSimulationClient
from zipline.core.process import ProcessSimulator
from zipline.core.monitor import Controller
from zipline.finance.trading import SIMULATION_STYLE
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
from zipline.utils import factory
log = logbook.Logger('Lines')
from zipline.test_algorithms import TestAlgorithm
from zipline.gens.composites import \
date_sorted_sources, merged_transforms, sequential_transforms
from zipline.gens.transform import Passthrough, StatefulTransform
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
from logbook import Logger, NestedSetup, Processor
import zipline.protocol as zp
log = Logger('Lines')
class CancelSignal(Exception):
    """
    Exception type that takes no constructor arguments.

    SimulatedTrading.handle_exception treats an instance of this class
    as an orderly shutdown and does not report it.
    """

    def __init__(self):
        """Accept no arguments and store nothing."""
class SimulatedTrading(object):
"""
Zipline with::
- _no_ data sources.
- Trade simulation client, which is available to send callbacks on
events and also accept orders to be simulated.
- An order data source, which will receive orders from the trade
simulation client, and feed them into the event stream to be
serialized and order alongside all other data source events.
- transaction simulation transformation, which receives the order
events and estimates a theoretical execution price and volume.
def __init__(self,
sources,
transforms,
algorithm,
environment,
style,
results_socket_uri,
context,
sim_id):
All components in this zipline are subject to heartbeat checks and
a control monitor, which can kill the entire zipline in the event of
exceptions in one of the components or an external request to end the
simulation.
"""
self.date_sorted = date_sorted_sources(*sources)
self.transforms = transforms
# Formerly merged_transforms.
self.with_tnfms = sequential_transforms(self.date_sorted, *self.transforms)
self.trading_client = tsc(algorithm, environment, style)
self.gen = self.trading_client.simulate(self.with_tnfms)
self.results_uri = results_socket_uri
self.results_socket = None
self.context = context
self.sim_id = sim_id
def __init__(self, **config):
# optional process if we fork simulate into an
# independent process.
self.proc = None
self.send_sighup = False
self.logger = Logger(sim_id)
self.print_logger = Logger('Print')
# exit status flag
self.success = False
def simulate(self, blocking=True, send_sighup=False):
# for non-blocking,
if blocking:
self.run_gen()
else:
self.send_sighup = send_sighup
return self.fork_and_sim()
def fork_and_sim(self):
self.proc = multiprocessing.Process(target=self.run_gen)
self.proc.start()
return self.proc
def run_gen(self):
setproctitle(self.sim_id)
self.open()
if self.zmq_out:
with self.zmq_out.threadbound():
self.stream_results()
# if no log socket, just run the algo normally
else:
self.stream_results()
def stream_results(self):
assert self.results_socket, \
"Results socket must exist to stream results"
try:
for event in self.gen:
if event.has_key('daily_perf'):
msg = zp.PERF_FRAME(event)
else:
msg = zp.RISK_FRAME(event)
self.results_socket.send(msg)
self.signal_done()
self.success = True
except Exception as exc:
self.handle_exception(exc)
finally:
# not much to do besides log our exit.
self.close()
def signal_done(self):
# notify monitor we're done
done_frame = zp.DONE_FRAME('success')
self.results_socket.send(done_frame)
def close(self):
log.info("Closing Simulation: {id}".format(id=self.sim_id))
if self.proc and self.send_sighup:
ppid = os.getppid()
if self.success:
log.warning("Sending SIGHUP")
os.kill(ppid, SIGHUP)
else:
log.warning("Sending SIGINT")
os.kill(ppid, SIGINT)
def handle_exception(self, exc):
if isinstance(exc, CancelSignal):
# signal from monitor of an orderly shutdown,
# do nothing.
pass
else:
self.signal_exception(exc)
def signal_exception(self, exc=None):
"""
:param config: a dict with the following required properties::
All exceptions inside any component should boil back to
this handler.
- algorithm: a class that follows the algorithm protocol. See
:py:meth:`zipline.finance.trading.TradeSimulationClient.add_algorithm
for details.
- trading_environment: an instance of
:py:class:`zipline.trading.TradingEnvironment`
- allocator: an instance of
:py:class:`zipline.simulator.AddressAllocator`
- simulation_style: optional parameter that configures the
:py:class:`zipline.finance.trading.TransactionSimulator`. Expects
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
Will inform the system that the component has failed and how it
has failed.
"""
assert isinstance(config, dict)
self.algorithm = config['algorithm']
self.allocator = config['allocator']
self.trading_environment = config['trading_environment']
self.sim_style = config.get('simulation_style')
self.send_sighup = config.get('send_sighup', False)
exc_type, exc_value, exc_traceback = sys.exc_info()
try:
log.exception('{id} sending exception to result stream.'\
.format(id=self.sim_id))
msg = zp.EXCEPTION_FRAME(
exc_traceback,
exc_type.__name__,
exc_value.message
)
self.leased_sockets = []
self.sim_context = None
self.results_socket.send(msg)
sockets = self.allocate_sockets(7)
addresses = {
'sync_address' : sockets[0],
'data_address' : sockets[1],
'feed_address' : sockets[2],
'merge_address' : sockets[3],
# TODO: this refers to the results of the merge, a
# horribly confusing name for the socket.
'results_address' : sockets[4],
}
except:
log.exception("Exception while reporting simulation exception.")
self.con = Controller(
sockets[5],
sockets[6],
self.send_sighup
def open(self):
if not self.context:
self.context = zmq.Context()
if self.results_uri:
sock = self.context.socket(zmq.PUSH)
sock.connect(self.results_uri)
self.results_socket = sock
self.setup_logging()
def setup_logging(self):
assert self.results_socket
# The filter behavior is: matches are logged, mismatches
# are bubbled. If bubble is True, matches are also
# bubbled. Since we do not want user logs in our system
# logs, we set bubble to False.
self.zmq_out = ZeroMQLogHandler(
socket = self.results_socket,
filter = lambda r, h: r.channel in ['Print', 'AlgoLog'],
bubble=False
)
self.started = False
def join(self):
if self.proc:
self.proc.join()
self.sim = ProcessSimulator(addresses)
self.clients = {}
self.trading_client = TradeSimulationClient(
self.trading_environment,
self.sim_style,
config['results_socket']
)
self.add_client(self.trading_client)
# setup all sources
self.sources = {}
#setup transforms
self.transforms = {}
self.sim.register_controller( self.con )
self.trading_client.set_algorithm(self.algorithm)
def get_pids(self):
if self.proc:
return [self.proc.pid]
else:
return []
@staticmethod
def create_test_zipline(**config):
@@ -167,7 +251,6 @@ class SimulatedTrading(object):
- environment - a \
:py:class:`zipline.finance.trading.TradingEnvironment`
- allocator - a :py:class:`zipline.simulator.AddressAllocator`
- sid - an integer, which will be used as the security ID.
- order_count - the number of orders the test algo will place,
defaults to 100
@@ -182,10 +265,10 @@ class SimulatedTrading(object):
- simulation_style: optional parameter that configures the
:py:class:`zipline.finance.trading.TransactionSimulator`. Expects
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
- transforms: optional parameter that provides a list
of StatefulTransform objects.
"""
assert isinstance(config, dict)
allocator = config['allocator']
sid = config['sid']
#--------------------
@@ -217,6 +300,10 @@ class SimulatedTrading(object):
if not simulation_style:
simulation_style = SIMULATION_STYLE.FIXED_SLIPPAGE
zmq_context = config.get('zmq_context', None)
simulation_id = config.get('simulation_id', 'test_simulation')
results_socket_uri = config.get('results_socket_uri', None)
#-------------------
# Trade Source
#-------------------
@@ -230,6 +317,12 @@ class SimulatedTrading(object):
trade_count,
trading_environment
)
#-------------------
# Transforms
#-------------------
transforms = config.get('transforms', [])
#-------------------
# Create the Algo
#-------------------
@@ -242,157 +335,19 @@ class SimulatedTrading(object):
order_count
)
if config.has_key('results_socket'):
results_socket = config['results_socket']
else:
results_socket = None
#-------------------
# Simulation
#-------------------
zipline = SimulatedTrading(**{
'algorithm' : test_algo,
'trading_environment' : trading_environment,
'allocator' : allocator,
'simulation_style' : simulation_style,
'results_socket' : results_socket,
})
sim = SimulatedTrading(
[trade_source],
transforms,
test_algo,
trading_environment,
simulation_style,
results_socket_uri,
zmq_context,
simulation_id)
#-------------------
zipline.add_source(trade_source)
return zipline
def add_source(self, source):
"""
Adds the source to the zipline, sets the sid filter of the
source to the algorithm's sid filter.
"""
assert isinstance(source, DataSource)
self.check_started()
source.set_filter('sid', self.algorithm.get_sid_filter())
self.sim.register_components([source])
# ``id`` is name of source_id, ``get_id`` is the class name
self.sources[source.get_id] = source
def add_transform(self, transform):
assert isinstance(transform, BaseTransform)
self.check_started()
self.sim.register_components([transform])
self.transforms[transform.get_id] = transform
def add_client(self, client):
assert isinstance(client, TradeSimulationClient)
self.check_started()
self.sim.register_components([client])
self.clients[client.get_id] = client
def check_started(self):
if self.started:
raise ZiplineException("TradeSimulation", "You cannot add \
components after the simulation has begun.")
def get_cumulative_performance(self):
return self.trading_client.perf.cumulative_performance.to_dict()
def allocate_sockets(self, n):
"""
Allocate sockets local to this line, track them so
we can gc after test run.
"""
assert isinstance(n, int)
assert n > 0
leased = self.allocator.lease(n)
self.leased_sockets.extend(leased)
return leased
@property
def components(self):
"""
Return the component instances inside of this topology
"""
base = set(self.sim.components.values())
transforms = set(self.transforms.values())
sources = set(self.sources.values())
return base | transforms | sources
@property
def topology(self):
"""
Returns the Component names in the topology of the
backtest.
"""
# A complete topology is the union of three classes of
# components added individually to the simulation client
# at various places.
#
# base : ['FEED', 'MERGE', 'TRADING_CLIENT', 'PASSTHROUGH']
# transforms : ['vwap__01', ... ]
# sources : ['MongoTradeHistory', ... ]
base = set(self.sim.components.keys())
transforms = set(self.transforms.keys())
sources = set(self.sources.keys())
return base | transforms | sources
def setup_controller(self):
"""
Prepare the controller to manage the topology specified
by this line.
"""
self.con.manage(self.topology)
def simulate(self, blocking=True):
self.setup_controller()
self.started = True
self.sim_context = self.sim.simulate()
# If we're using a threaded simulator block on the pool
# of thread since we're only ever in a test and we don't
# generally monitor the state of the system as a hold at
# the supervisory layer
# TODO: better way of identifying concurrency substrate
if blocking:
for process in self.sim.subprocesses:
process.join()
@property
def is_success(self):
# TODO: other assertions?
if self.sim.did_clean_shutdown():
return True
else:
return False
#--------------------------------
# Component property accessors
#--------------------------------
def get_positions(self):
"""
returns current positions as a dict. draws from the cumulative
performance period in the performance tracker.
"""
perf = self.trading_client.perf.cumulative_performance
positions = perf.get_positions()
return positions
class ZiplineException(Exception):
def __init__(self, zipline_name, msg):
self.name = zipline_name
self.message = msg
def __str__(self):
return "Unexpected exception {line}: {msg}".format(
line=self.name,
msg=self.message
)
return sim
+12 -2
View File
@@ -131,7 +131,7 @@ from utils.date_utils import EPOCH, UN_EPOCH, epoch_now
# Control Protocol
# -----------------------
PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION', 'CANCEL']
PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION','CANCEL','DONE', 'LOG']
INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL')
@@ -527,11 +527,15 @@ def EXCEPTION_FRAME(exception_tb, name, message):
rlist = []
for stack in stack_list:
filename = shorten_filename(stack[0])
# default the line to empty string rather than None
line = ''
if stack[3]:
line = stack[3]
rstack = {
'filename' : filename,
'lineno' : stack[1],
'method' : stack[2],
'line' : stack[3]
'line' : line
}
rlist.append(rstack)
result = {
@@ -570,6 +574,12 @@ def CANCEL_FRAME(date):
return BT_UPDATE_FRAME('CANCEL', result)
def DONE_FRAME(msg):
    """
    Frame a 'DONE' update whose payload is the string ``msg``.

    The framing itself is delegated to BT_UPDATE_FRAME.
    """
    assert isinstance(msg, basestring), \
        "Done message must be a string."
    frame = BT_UPDATE_FRAME('DONE', msg)
    return frame
def BT_UPDATE_FRAME(prefix, payload):
"""
+26
View File
@@ -221,6 +221,32 @@ class DivByZeroAlgorithm():
def get_sid_filter(self):
return [self.sid]
class TimeoutAlgorithm():
    """
    Test algorithm that stalls inside handle_data.

    The first five calls to handle_data return immediately; every later
    call sleeps for 100 seconds. All other algorithm hooks are no-ops.
    """

    def __init__(self, sid):
        self.sid = sid
        # Number of handle_data calls completed so far.
        self.incr = 0

    def initialize(self):
        pass

    def set_order(self, order_callable):
        pass

    def set_logger(self, logger):
        pass

    def set_portfolio(self, portfolio):
        pass

    def handle_data(self, data):
        """Count the call, sleeping first once more than four have been seen."""
        if self.incr > 4:
            import time
            time.sleep(100)
        # Without this increment the counter stays at 0 and the sleep
        # above can never be reached.
        self.incr += 1

    def get_sid_filter(self):
        """Return the single sid this algorithm is interested in, as a list."""
        return [self.sid]
class TestPrintAlgorithm():
+1 -1
View File
@@ -95,7 +95,7 @@ HOLIDAYS = {
'july_4th' : datetime(2008 , 7 , 4 ),
'labor_day' : datetime(2008 , 9 , 1 ),
'tgiving' : datetime(2008 , 11 , 27),
'christmas' : datetime(2008 , 5 , 25),
'christmas' : datetime(2008 , 12 , 25),
}
# Create a rule to recur every weekday starting today
+15 -35
View File
@@ -12,7 +12,9 @@ from datetime import datetime, timedelta
import zipline.finance.risk as risk
import zipline.protocol as zp
from zipline.finance.sources import SpecificEquityTrades, RandomEquityTrades
from zipline.gens.tradegens import RandomEquityTrades
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.utils import create_trade
from zipline.finance.trading import TradingEnvironment
# TODO
@@ -69,16 +71,6 @@ def create_trading_environment(year=2006):
return trading_environment
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
row = zp.ndict({
'source_id' : source_id,
'type' : zp.DATASOURCE_TYPE.TRADE,
'sid' : sid,
'dt' : datetime,
'price' : price,
'volume' : amount
})
return row
def get_next_trading_dt(current, interval, trading_calendar):
next = current
@@ -89,8 +81,6 @@ def get_next_trading_dt(current, interval, trading_calendar):
return next
def create_trade_history(sid, prices, amounts, interval, trading_calendar):
trades = []
current = trading_calendar.first_open
@@ -222,29 +212,19 @@ def create_minutely_trade_source(sids, trade_count, trading_environment):
)
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment):
trade_history = []
price = [10.1] * trade_count
volume = [100] * trade_count
args = tuple()
kwargs = {
'count' : trade_count,
'sids' : sids,
'start' : trading_environment.first_open,
'delta' : trade_time_increment,
'filter' : sids
}
source = SpecificEquityTrades(*args, **kwargs)
for sid in sids:
start_date = trading_environment.first_open
# TODO: do we need to set the trading environment's end to same dt as
# the last trade in the history?
#trading_environment.period_end = trade_history[-1].dt
generated_trades = create_trade_history(
sid,
price,
volume,
trade_time_increment,
trading_environment
)
trade_history.extend(generated_trades)
trade_history = sorted(trade_history, key=attrgetter('dt'))
#set the trading environment's end to same dt as the last trade in the
#history.
trading_environment.period_end = trade_history[-1].dt
source = SpecificEquityTrades(trade_history)
return source
+8 -6
View File
@@ -120,12 +120,6 @@ class ZeroMQLogHandler(Handler):
#can't be serialized by JSON, so we need to convert to
#unix epoch representation.
if record.time:
assert isinstance(record.time, datetime.datetime)
time = record.time.replace(tzinfo = pytz.utc)
#logbook measures time in utc already, no need to convert.
record.time = EPOCH(time)
#Do the same if algo_dt is a datetime object.
if record.extra.has_key('algo_dt'):
@@ -151,6 +145,14 @@ class ZeroMQLogHandler(Handler):
data[field] = record.extra[field]
else:
data[field] = None
if data['time']:
assert isinstance(data['time'], datetime.datetime)
time = data['time'].replace(tzinfo = pytz.utc)
#logbook measures time in utc already, no need to convert.
data['time'] = EPOCH(time)
return data
def emit(self, record):
+81 -19
View File
@@ -1,11 +1,14 @@
import multiprocessing
import zmq
import time
import zipline.protocol as zp
from datetime import datetime
import blist
from bson import ObjectId
from zipline.utils.date_utils import EPOCH
from itertools import izip
from logbook import FileHandler
from zipline.core.monitor import Monitor
def setup_logger(test, path='/var/log/zipline/zipline.log'):
test.log_handler = FileHandler(path)
@@ -29,6 +32,7 @@ def check_dict(test, a, b, label):
# ignore the extra fields used by dictshield
if key in ['progress']:
continue
test.assertTrue(a.has_key(key), "missing key at: " + label + "." + key)
test.assertTrue(b.has_key(key), "missing key at: " + label + "." + key)
a_val = a[key]
@@ -58,15 +62,22 @@ def check(test, a, b, label=None):
else:
test.assertEqual(a, b, "mismatch on path: " + label)
def check_excluded(test, a, excluded_keys=()):
    """
    Recursively assert that the dict ``a`` exposes no unwanted fields.

    For every key/value pair, at every nesting level, ``test`` is asked
    to assert that:
      - the key is not listed in ``excluded_keys``,
      - the key does not end in '_id',
      - the value is not an ObjectId.

    ``test`` must provide assertTrue and assertFalse. The default for
    ``excluded_keys`` is an immutable empty tuple rather than a shared
    mutable list.
    """
    # items() rather than iteritems(): same pairs, and a must not be
    # modified while we walk it anyway.
    for key, value in a.items():
        test.assertTrue(key not in excluded_keys)
        test.assertFalse(key.endswith('_id'), 'Avoid _id fields!')
        test.assertFalse(isinstance(value, ObjectId))
        if isinstance(value, dict):
            check_excluded(test, value, excluded_keys)
def drain_zipline(test, zipline):
def drain_zipline(test, zipline, p_blocking=False):
assert test.ctx, "method expects a valid zmq context"
assert test.zipline_test_config, "method expects a valid test config"
assert isinstance(test.zipline_test_config, dict)
assert test.zipline_test_config['results_socket'], \
assert test.zipline_test_config['results_socket_uri'], \
"need to specify a socket address for logs/perf/risk"
test.receiver = create_receiver(
test.zipline_test_config['results_socket'],
test.zipline_test_config['results_socket_uri'],
test.ctx
)
# Bind and connect are asynch, so allow time for bind before
@@ -74,13 +85,12 @@ def drain_zipline(test, zipline):
time.sleep(1)
# start the simulation
zipline.simulate(blocking=False)
zipline.simulate(blocking=p_blocking)
output, transaction_count = drain_receiver(test.receiver)
# some processes will exit after the message stream is
# finished. We block here to avoid collisions with subsequent
# ziplines.
for process in zipline.sim.subprocesses:
process.join()
zipline.join()
return output, transaction_count
@@ -95,16 +105,15 @@ def drain_receiver(receiver):
transaction_count = 0
while True:
msg = receiver.recv()
if msg == str(zp.CONTROL_PROTOCOL.DONE):
update = zp.BT_UPDATE_UNFRAME(msg)
output.append(update)
if update['prefix'] == 'PERF':
transaction_count += \
len(update['payload']['daily_perf']['transactions'])
elif update['prefix'] == 'EXCEPTION':
break
elif update['prefix'] == 'DONE':
break
else:
update = zp.BT_UPDATE_UNFRAME(msg)
output.append(update)
if update['prefix'] == 'PERF':
transaction_count += \
len(update['payload']['daily_perf']['transactions'])
elif update['prefix'] == 'EXCEPTION':
break
receiver.close()
del receiver
@@ -115,9 +124,6 @@ def drain_receiver(receiver):
def assert_single_position(test, zipline):
output, transaction_count = drain_zipline(test, zipline)
test.assertTrue(zipline.sim.ready())
test.assertFalse(zipline.sim.exception)
test.assertEqual(
test.zipline_test_config['order_count'],
transaction_count
@@ -126,7 +132,8 @@ def assert_single_position(test, zipline):
# the final message is the risk report, the second to
# last is the final day's results. Positions is a list of
# dicts.
closing_positions = output[-2]['payload']['daily_perf']['positions']
perfs = [x for x in output if x['prefix'] == 'PERF']
closing_positions = perfs[-2]['payload']['daily_perf']['positions']
test.assertEqual(
len(closing_positions),
@@ -140,3 +147,58 @@ def assert_single_position(test, zipline):
sid,
"Portfolio should have one position in " + str(sid)
)
def launch_component(component):
    """
    Run ``component.run`` in a new multiprocessing.Process.

    The process is started before it is returned.
    """
    worker = multiprocessing.Process(target=component.run)
    worker.start()
    return worker
def launch_monitor(monitor):
    """
    Run ``monitor.run`` in a new multiprocessing.Process.

    The process is started before it is returned.
    """
    worker = multiprocessing.Process(target=monitor.run)
    worker.start()
    return worker
def create_monitor(allocator):
    """
    Lease three socket addresses from ``allocator`` and build a Monitor
    on them.

    The addresses are passed positionally in lease order: pub socket,
    route socket, exception socket.
    """
    leased = allocator.lease(3)
    pub_socket = leased[0]
    route_socket = leased[1]
    # The exception socket is meant to match tradesimclient's result
    # socket, so that exceptions are relayed to the same listener.
    exception_socket = leased[2]

    # This monitor is expected to run in a test, so there is no need to
    # signal the parent process on success or error.
    return Monitor(
        pub_socket,
        route_socket,
        exception_socket,
        send_sighup=False
    )
class ExceptionSource(object):
    """
    Test fixture: an iterator whose ``next`` always fails.

    ``next`` evaluates ``5 / 0``, so pulling a value from this source
    raises ZeroDivisionError.
    """

    def __init__(self):
        # No state to set up.
        pass

    def get_hash(self):
        """Return the fixed identifier string for this source."""
        return "ExceptionSource"

    def __iter__(self):
        """Return self; the object acts as its own iterator."""
        return self

    def next(self):
        """Raise ZeroDivisionError unconditionally."""
        # Deliberate division by zero to simulate a failing source.
        5 / 0
class ExceptionTransform(object):
    """
    Test fixture: a transform whose ``update`` always fails.

    Every call to ``update`` raises AssertionError with the message
    "An assertion message".
    """

    def __init__(self):
        # No state to set up.
        pass

    def get_hash(self):
        """Return the fixed identifier string for this transform."""
        return "ExceptionTransform"

    def update(self, event):
        """
        Raise AssertionError regardless of ``event``.

        The exception is raised explicitly rather than through an
        ``assert`` statement, because ``assert`` is compiled away under
        ``python -O`` and the fixture would then silently return None.
        """
        raise AssertionError("An assertion message")
+358
View File
@@ -0,0 +1,358 @@
import pytz
from datetime import datetime, timedelta
from dateutil import rrule
from zipline.utils.date_utils import utcnow
def market_opens(start, end, inclusive=False):
    """
    Return the market opens that fall between ``start`` and ``end``.

    Both bounds must be utc-stamped datetimes. The bounds themselves
    are only included when ``inclusive`` is true.
    """
    opens_in_range = opens.between(start, end, inc=inclusive)
    return opens_in_range
def market_closes(start, end, inclusive=False):
    """
    Return the market closes that fall between ``start`` and ``end``.

    Both bounds must be utc-stamped datetimes. The bounds themselves
    are only included when ``inclusive`` is true.
    """
    closes_in_range = closes.between(start, end, inc=inclusive)
    return closes_in_range
def trading_days_between(start, end):
    """
    Count the "complete" trading days between two events.

    The count is the number of market opens strictly between start and
    end, minus 1 when both of the following hold: the last of those
    opens is on the same calendar day as end, and end falls earlier in
    its own day than start does in its day. In that case the day that
    began with the last open has not yet been completed.

    Examples:

    1.)
    start = Tuesday, Aug 7, 2012, 1:00 pm
    end = Wednesday, Aug 8, 2012, 1:30 pm
    There is one market open between these dates, on the morning of
    Wednesday the 8th. It is on the same calendar day as end, but end
    is later in the day than start, so the day counts as complete.
    The correct output is 1.

    2.)
    start = Tuesday, Aug 7, 2012, 1:30 pm
    end = Wednesday, Aug 8, 2012, 1:00 pm
    There is one market open between these dates, on the morning of
    Wednesday the 8th. It is on the same calendar day as end, and end
    is earlier in the day than start, so the day is not counted as
    complete. The correct output is 0.

    3.)
    start = Tuesday, Aug 7, 2012, 1:00 pm
    end = Saturday, Aug 11, 2012, 1:30 pm
    There are 3 market opens between these dates, on Wednesday,
    Thursday, and Friday. The last open is not on the same day as
    end, so the result is simply 3.

    4.)
    start = Tuesday, Aug 7, 2012, 1:30 pm
    end = Monday, Aug 13, 2012, 1:00 pm
    There are 4 market opens between these dates, on Wednesday,
    Thursday, Friday, and the following Monday. The last open is on
    the same calendar day as end, and end is earlier in the day than
    start, so the last market day is not counted as complete. The
    correct output is 3.
    """
    # Opens strictly between the two events (bounds excluded).
    opens_in_range = market_opens(start, end)
    completed_days = len(opens_in_range)
    if not completed_days:
        return 0

    # The day started by the final open only counts once end has
    # reached the time of day at which start occurred.
    final_open = opens_in_range[-1]
    if final_open.date() == end.date() and earlier_in_day(end, start):
        return completed_days - 1
    return completed_days
def earlier_in_day(d1, d2):
    """
    Tell whether d1's time of day is strictly before d2's.

    Only the ``time()`` component of each datetime is compared; the
    calendar dates are ignored.
    """
    first_time = d1.time()
    second_time = d2.time()
    return second_time > first_time
WEEKDAYS = [rrule.MO, rrule.TU, rrule.WE, rrule.TH, rrule.FR]

# Recurrence rule that generates a market open at 14:30 UTC on every
# weekday from Jan 1, 2000 until Jan 1, 2014.
# This does not exclude holidays.
# NOTE(review): the hour is fixed in UTC, so it does not shift with US
# daylight saving time -- confirm this is intended.
market_opens_with_holidays = rrule.rrule(
    rrule.DAILY,
    byweekday=WEEKDAYS,
    byhour=14,
    byminute=30,
    cache=True,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
)

# Recurrence rule that generates a market close at 21:00 UTC on every
# weekday from Jan 1, 2000 until Jan 1, 2014.
# This does not exclude holidays.
# The start date matches the opens rule and the holiday rules (all
# 2000); it previously started in 2001, which left the year 2000 with
# market opens but no market closes.
market_closes_with_holidays = rrule.rrule(
    rrule.DAILY,
    byweekday=WEEKDAYS,
    byhour=21,
    byminute=0,
    cache=True,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
)
# New Year's Day: the open (14:30 UTC) and close (21:00 UTC) on the
# first day of each year, used to exclude that day from the calendar.
new_years_opens = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    byyearday=1,
    byhour=14,
    byminute=30,
    cache=True
)

new_years_closes = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    byyearday=1,
    byhour=21,
    byminute=0,
    cache=True
)
# Martin Luther King Jr. Day: the open and close on the third Monday
# of January, used to exclude that day from the calendar.
mlk_opens = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=1,
    byweekday=rrule.MO(3),
    byhour=14,
    byminute=30,
    cache=True
)

mlk_closes = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=1,
    byweekday=rrule.MO(3),
    byhour=21,
    byminute=0,
    cache=True
)
# Presidents' Day: the open and close on the third Monday of
# February.
presidents_day_opens = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=2,
    byweekday=rrule.MO(3),
    byhour=14,
    byminute=30,
    cache=True
)

presidents_day_closes = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=2,
    byweekday=rrule.MO(3),
    byhour=21,
    byminute=0,
    cache=True
)
# Good Friday: the open and close two days before Easter. Easter is
# thankfully a built-in reference in the rrule module (byeaster).
good_friday_opens = rrule.rrule(
    rrule.DAILY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    byeaster=-2,
    byhour=14,
    byminute=30,
    cache=True
)

good_friday_closes = rrule.rrule(
    rrule.DAILY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    byeaster=-2,
    byhour=21,
    byminute=0,
    cache=True
)
# Memorial Day: the open and close on the last Monday of May.
memorial_day_opens = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=5,
    byweekday=rrule.MO(-1),
    byhour=14,
    byminute=30,
    cache=True
)

memorial_day_closes = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=5,
    byweekday=rrule.MO(-1),
    byhour=21,
    byminute=0,
    cache=True
)
# Recurrence rules for generating the market open/close for July 4th.
# The month must be 7 (July); these rules previously used month 6,
# which excluded June 4th and left July 4th as a trading day.
july_4th_opens = rrule.rrule(
    rrule.MONTHLY,
    bymonth=7,
    bymonthday=4,
    byhour=14,
    byminute=30,
    cache=True,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
)

july_4th_closes = rrule.rrule(
    rrule.MONTHLY,
    bymonth=7,
    bymonthday=4,
    byhour=21,
    byminute=0,
    cache=True,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
)
# Labor Day: the open and close on the first Monday of September.
labor_day_opens = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=9,
    byweekday=rrule.MO(1),
    byhour=14,
    byminute=30,
    cache=True
)

labor_day_closes = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=9,
    byweekday=rrule.MO(1),
    byhour=21,
    byminute=0,
    cache=True
)
# Recurrence rule for generating the market open/close for
# thanksgiving. Thanksgiving always falls on the fourth thursday in
# November. (Who decides how these holidays work!?!)
# The rule uses TH(4), the fourth Thursday, rather than TH(-1), the
# last Thursday: the two differ whenever November has five Thursdays
# (e.g. 2012, where the fourth is Nov 22 and the last is Nov 29).
thanksgiving_opens = rrule.rrule(
    rrule.MONTHLY,
    bymonth=11,
    byweekday=rrule.TH(4),
    byhour=14,
    byminute=30,
    cache=True,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
)

thanksgiving_closes = rrule.rrule(
    rrule.MONTHLY,
    bymonth=11,
    byweekday=rrule.TH(4),
    byhour=21,
    byminute=0,
    cache=True,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
)
# Christmas: the open and close on December 25th.
christmas_opens = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=12,
    bymonthday=25,
    byhour=14,
    byminute=30,
    cache=True
)

christmas_closes = rrule.rrule(
    rrule.MONTHLY,
    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
    until=datetime(2014, 1, 1, tzinfo=pytz.utc),
    bymonth=12,
    bymonthday=25,
    byhour=21,
    byminute=0,
    cache=True
)
# Holiday rules to exclude from the weekday calendar, in calendar
# order. The opens list and the closes list cover the same holidays.
holiday_opens = [
    new_years_opens,
    mlk_opens,
    presidents_day_opens,
    good_friday_opens,
    memorial_day_opens,
    july_4th_opens,
    labor_day_opens,
    thanksgiving_opens,
    christmas_opens,
]

holiday_closes = [
    new_years_closes,
    mlk_closes,
    presidents_day_closes,
    good_friday_closes,
    memorial_day_closes,
    july_4th_closes,
    labor_day_closes,
    thanksgiving_closes,
    christmas_closes,
]
# Valid market opens/closes are every weekday open/close minus the
# ones produced by a holiday rule.
opens = rrule.rruleset(cache=True)
closes = rrule.rruleset(cache=True)

opens.rrule(market_opens_with_holidays)
closes.rrule(market_closes_with_holidays)

for holiday_rule in holiday_opens:
    opens.exrule(holiday_rule)

for holiday_rule in holiday_closes:
    closes.exrule(holiday_rule)

# Counting walks each rule set from start to finish, which loads the
# whole calendar into the cache up front.
open_count = opens.count()
close_count = closes.count()