diff --git a/tests/test_components.py b/tests/test_components.py new file mode 100644 index 00000000..785a0ae2 --- /dev/null +++ b/tests/test_components.py @@ -0,0 +1,381 @@ +import zmq +import pytz +from pprint import pformat as pf +from datetime import datetime, timedelta + +from unittest2 import TestCase +from collections import defaultdict +from zipline.gens.composites import date_sorted_sources, merged_transforms + +from zipline.core.devsimulator import AddressAllocator +from zipline.gens.transform import Passthrough, StatefulTransform +from zipline.gens.mavg import MovingAverage +from zipline.gens.tradesimulation import TradeSimulationClient as tsc + +from zipline.utils.factory import create_trading_environment +from zipline.test_algorithms import TestAlgorithm + + +from zipline.utils.test_utils import ( + setup_logger, + teardown_logger, + create_monitor, + launch_monitor +) + +from zipline.core import Component +from zipline.protocol import ( + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + FEED_FRAME, + FEED_UNFRAME, + MERGE_FRAME, + MERGE_UNFRAME, + SIMULATION_STYLE, + PERF_FRAME, + BT_UPDATE_UNFRAME +) + +from zipline.gens.tradegens import SpecificEquityTrades + +import logbook +log = logbook.Logger('ComponentTestCase') + +allocator = AddressAllocator(1000) + + +class ComponentTestCase(TestCase): + + leased_sockets = defaultdict(list) + + def setUp(self): + self.zipline_test_config = { + 'allocator' : allocator, + 'sid' : 133, + 'devel' : False, + 'results_socket' : allocator.lease(1)[0], + 'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE + } + self.ctx = zmq.Context() + setup_logger(self) + + count = 250 + filter = [2,3] + #Set up source a. One minute between events. + args_a = tuple() + kwargs_a = { + 'count' : 2*count, + 'sids' : [1,2,3], + 'start' : datetime(2002,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(hours = 6), + 'filter' : filter + } + self.source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. 
Two minutes between events. + args_b = tuple() + kwargs_b = { + 'count' : count, + 'sids' : [2,3,4], + 'start' : datetime(2002,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + self.source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + self.environment = create_trading_environment(year = 2002) + + + + def tearDown(self): + teardown_logger(self) + + def test_source(self): + monitor = create_monitor(allocator) + socket_uri = allocator.lease(1)[0] + count = 100 + + filter = [1,2,3,4] + #Set up source a. One minute between events. + args_a = tuple() + kwargs_a = { + 'sids' : [1,2], + 'start' : datetime(2012,6,6,0,tzinfo=pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + + trade_gen = SpecificEquityTrades(*args_a, **kwargs_a) + + + comp_a = Component( + trade_gen, + monitor, + socket_uri, + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + "source_a" + ) + + mon_proc = launch_monitor(monitor) + + for event in comp_a: + log.info(event) + + # wait for the sending process to exit + comp_a.proc.join() + mon_proc.join() + + def test_sort(self): + monitor = create_monitor(allocator) + socket_uris = allocator.lease(3) + count = 100 + + filter = [1,2,3,4] + #Set up source a. One minute between events. + args_a = tuple() + kwargs_a = { + 'sids' : [1,2], + 'start' : datetime(2012,6,6,0,tzinfo=pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + trade_gen_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. Two minutes between events. + args_b = tuple() + kwargs_b = { + 'sids' : [2], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + trade_gen_b = SpecificEquityTrades(*args_b, **kwargs_b) + + #Set up source c. Three minutes between events. 
+ args_c = tuple() + kwargs_c = { + 'sids' : [3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + + trade_gen_c = SpecificEquityTrades(*args_c, **kwargs_c) + + + comp_a = Component( + trade_gen_a, + monitor, + socket_uris[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + trade_gen_a.get_hash() + ) + + comp_b = Component( + trade_gen_b, + monitor, + socket_uris[1], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + trade_gen_b.get_hash() + ) + + comp_c = Component( + trade_gen_c, + monitor, + socket_uris[2], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + trade_gen_c.get_hash() + ) + + sources = [comp_a, comp_b, comp_c] + + sorted_out = date_sorted_sources(*sources) + + mon_proc = launch_monitor(monitor) + + prev = None + sort_count = 0 + for msg in sorted_out: + if prev: + self.assertTrue(msg.dt >= prev.dt, \ + "Messages should be in date ascending order") + prev = msg + sort_count += 1 + + self.assertEqual(count*3, sort_count) + + # wait for processes to finish + comp_a.proc.join() + comp_b.proc.join() + comp_c.proc.join() + mon_proc.join() + + + def test_full(self): + monitor = create_monitor(allocator) + + # ------------------------ + # Run sources in dedicated processes + comp_a = Component( + self.source_a, + monitor, + allocator.lease(1)[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + self.source_a.get_hash() + ) + + comp_b = Component( + self.source_b, + monitor, + allocator.lease(1)[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + self.source_b.get_hash() + ) + + # Date sort the sources, and run the sort in a dedicated + # process + sources = [comp_a, comp_b] + + sorted_out = date_sorted_sources(*sources) + + sorted = Component( + sorted_out, + monitor, + allocator.lease(1)[0], + FEED_FRAME, + FEED_UNFRAME, + "sort" + ) + + + passthrough = StatefulTransform(Passthrough) + mavg_price = StatefulTransform( + MovingAverage, + ['price'], + market_aware = False, + delta=timedelta(minutes = 
20) + ) + + merged_gen = merged_transforms(sorted, passthrough, mavg_price) + + merged = Component( + merged_gen, + monitor, + allocator.lease(1)[0], + MERGE_FRAME, + MERGE_UNFRAME, + "merge" + ) + + algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) + + style = SIMULATION_STYLE.FIXED_SLIPPAGE + + trading_client = tsc(algo, self.environment, style) + tsc_gen = trading_client.simulate(merged) + + tsc_comp = Component( + tsc_gen, + monitor, + allocator.lease(1)[0], + PERF_FRAME, + BT_UPDATE_UNFRAME, + "tsc" + ) + mon_proc = launch_monitor(monitor) + for message in tsc_comp: + log.info(pf(message)) + + + # wait for processes to finish + comp_a.proc.join() + comp_b.proc.join() + sorted.proc.join() + merged.proc.join() + tsc_comp.proc.join() + mon_proc.join() + return + + def test_single_thread(self): + + #Set up source c. Three minutes between events. + + sorted = date_sorted_sources(self.source_a, self.source_b) + + passthrough = StatefulTransform(Passthrough) + mavg_price = StatefulTransform( + MovingAverage, + ['price'], + market_aware=False, + delta=timedelta(minutes = 20), + ) + + merged = merged_transforms(sorted, passthrough, mavg_price) + + algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) + style = SIMULATION_STYLE.FIXED_SLIPPAGE + + trading_client = tsc(algo, self.environment, style) + for message in trading_client.simulate(merged): + log.info(pf(message)) + + def test_compound(self): + monitor = create_monitor(allocator) + + sorted_out = date_sorted_sources(self.source_a, self.source_b) + + sorted = Component( + sorted_out, + monitor, + allocator.lease(1)[0], + FEED_FRAME, + FEED_UNFRAME, + "feed" + ) + + passthrough = StatefulTransform(Passthrough) + mavg_price = StatefulTransform( + MovingAverage, + ['price'], + market_aware = False, + delta=timedelta(minutes = 20) + ) + + merged_gen = merged_transforms(sorted, passthrough, mavg_price) + + merged = Component( + merged_gen, + monitor, + allocator.lease(1)[0], + MERGE_FRAME, + MERGE_UNFRAME, + "merge" + 
) + + algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) + style = SIMULATION_STYLE.FIXED_SLIPPAGE + + trading_client = tsc(algo, self.environment, style) + tsc_gen = trading_client.simulate(merged) + + + mon_proc = launch_monitor(monitor) + for message in tsc_gen: + log.info(pf(message)) + + + # wait for processes to finish + sorted.proc.join() + merged.proc.join() + mon_proc.join() + return diff --git a/tests/test_exception_handling.py b/tests/test_exception_handling.py index b33a2acb..d1561837 100644 --- a/tests/test_exception_handling.py +++ b/tests/test_exception_handling.py @@ -7,30 +7,30 @@ from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm from zipline.finance.trading import SIMULATION_STYLE from zipline.core.devsimulator import AddressAllocator from zipline.lines import SimulatedTrading +from zipline.gens.transform import StatefulTransform from zipline.utils.test_utils import \ drain_zipline, \ check, \ setup_logger, \ - teardown_logger + teardown_logger, \ + ExceptionSource, \ + ExceptionTransform DEFAULT_TIMEOUT = 15 # seconds EXTENDED_TIMEOUT = 90 allocator = AddressAllocator(1000) - class ExceptionTestCase(TestCase): leased_sockets = defaultdict(list) def setUp(self): self.zipline_test_config = { - 'allocator' : allocator, - 'sid' : 133, - 'devel' : False, - 'results_socket' : allocator.lease(1)[0], - 'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE + 'sid' : 133, + 'results_socket_uri' : allocator.lease(1)[0], + 'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE } self.ctx = zmq.Context() setup_logger(self) @@ -39,6 +39,39 @@ class ExceptionTestCase(TestCase): self.ctx.term() teardown_logger(self) + def test_datasource_exception(self): + self.zipline_test_config['trade_source'] = ExceptionSource() + zipline = SimulatedTrading.create_test_zipline( + **self.zipline_test_config + ) + output, _ = drain_zipline(self, zipline) + assert len(output) == 1 + assert output[0]['prefix'] == 'EXCEPTION' + message = 
output[0]['payload'] + for field in ['date', 'message', 'name', 'stack']: + assert field in message.keys() + + assert message['message'] == 'integer division or modulo by zero' + assert message['name'] == 'ZeroDivisionError' + + def test_tranform_exception(self): + exc_tnfm = StatefulTransform(ExceptionTransform) + self.zipline_test_config['transforms'] = [exc_tnfm] + + zipline = SimulatedTrading.create_test_zipline( + **self.zipline_test_config + ) + output, _ = drain_zipline(self, zipline) + assert len(output) == 1 + assert output[0]['prefix'] == 'EXCEPTION' + message = output[0]['payload'] + for field in ['date', 'message', 'name', 'stack']: + assert field in message.keys() + + assert message['message'] == 'An assertion message' + assert message['name'] == 'AssertionError' + + def test_exception_in_init(self): # Simulation # ---------- @@ -52,19 +85,17 @@ class ExceptionTestCase(TestCase): **self.zipline_test_config ) output, _ = drain_zipline(self, zipline) - self.assertEqual(len(output), 1) + self.assertEqual(output[-1]['prefix'], 'EXCEPTION') payload = output[-1]['payload'] self.assertTrue(payload['date']) - del payload['date'] - check(self, payload, INITIALIZE_TB) - - self.assertTrue(zipline.sim.ready()) - self.assertFalse(zipline.sim.exception) - + self.assertEqual(payload['message'],'Algo exception in initialize') + self.assertEqual(payload['name'],'Exception') + # make sure our path shortening is working + self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py') + self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py') def test_exception_in_handle_data(self): - # Simulation # ---------- self.zipline_test_config['algorithm'] = \ @@ -78,16 +109,15 @@ class ExceptionTestCase(TestCase): ) output, _ = drain_zipline(self, zipline) - - self.assertEqual(len(output), 1) self.assertEqual(output[-1]['prefix'], 'EXCEPTION') payload = output[-1]['payload'] self.assertTrue(payload['date']) del payload['date'] - check(self, 
payload, HANDLE_DATA_TB) - - self.assertTrue(zipline.sim.ready()) - self.assertFalse(zipline.sim.exception) + self.assertEqual(payload['message'],'Algo exception in handle_data') + self.assertEqual(payload['name'],'Exception') + # make sure our path shortening is working + self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py') + self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py') def test_zerodivision_exception_in_handle_data(self): @@ -103,95 +133,13 @@ class ExceptionTestCase(TestCase): ) output, _ = drain_zipline(self, zipline) - self.assertEqual(len(output), 5) + self.assertEqual(output[-1]['prefix'], 'EXCEPTION') payload = output[-1]['payload'] self.assertTrue(payload['date']) del payload['date'] - check(self, payload, ZERO_DIV_TB) - - self.assertTrue(zipline.sim.ready()) - self.assertFalse(zipline.sim.exception) - - # TODO: - # - define more zipline failure modes: exception in other - # components, exception in Monitor, etc. write tests - # for those scenarios. 
- - - -INITIALIZE_TB =\ -{'message': 'Algo exception in initialize', - 'name': 'Exception', - 'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'}, - {'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'}, - {'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.initialize_algo()', - 'lineno': 97, - 'method': 'do_work'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.do_op(self.algorithm.initialize)', - 'lineno': 80, - 'method': 'initialize_algo'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'callable_op(*args, **kwargs)', - 'lineno': 210, - 'method': 'do_op'}, - {'filename': '/zipline/test_algorithms.py', - 'line': 'raise Exception("Algo exception in initialize")', - 'lineno': 166, - 'method': 'initialize'}]} - -HANDLE_DATA_TB =\ -{ - 'message': 'Algo exception in handle_data', - 'name': 'Exception', - 'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'}, - {'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'}, - {'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.process_event(event)', - 'lineno': 116, - 'method': 'do_work'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.run_algorithm()', - 'lineno': 164, - 'method': 'process_event'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.do_op(self.algorithm.handle_data, data)', - 'lineno': 186, - 'method': 'run_algorithm'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'callable_op(*args, **kwargs)', - 'lineno': 210, - 'method': 'do_op'}, - {'filename': 
'/zipline/test_algorithms.py', - 'line': 'raise Exception("Algo exception in handle_data")', - 'lineno': 187, - 'method': 'handle_data'}]} - - -ZERO_DIV_TB= \ -{'message': 'integer division or modulo by zero', - 'name': 'ZeroDivisionError', - 'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'}, - {'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'}, - {'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.process_event(event)', - 'lineno': 116, - 'method': 'do_work'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.run_algorithm()', - 'lineno': 164, - 'method': 'process_event'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'self.do_op(self.algorithm.handle_data, data)', - 'lineno': 186, - 'method': 'run_algorithm'}, - {'filename': '/zipline/components/tradesimulation.py', - 'line': 'callable_op(*args, **kwargs)', - 'lineno': 210, - 'method': 'do_op'}, - {'filename': '/zipline/test_algorithms.py', 'line': '5/0', 'lineno': 218, 'method': 'handle_data'}]} + self.assertEqual(payload['message'],'integer division or modulo by zero') + self.assertEqual(payload['name'],'ZeroDivisionError') + # make sure our path shortening is working + self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py') + self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py') diff --git a/tests/test_feed.py b/tests/test_feed.py deleted file mode 100644 index 21e8afb3..00000000 --- a/tests/test_feed.py +++ /dev/null @@ -1,233 +0,0 @@ -from unittest2 import TestCase -from itertools import cycle, chain -from datetime import datetime, timedelta -from collections import deque - -from zipline import ndict -from zipline.gens.sort import \ - date_sort, \ - ready, \ - done, \ - queue_is_ready,\ - 
queue_is_done -from zipline.gens.utils import hash_args, alternate -from zipline.gens.tradegens import date_gen, SpecificEquityTrades -from zipline.gens.composites import date_sorted_sources - -import zipline.protocol as zp - -class HelperTestCase(TestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_individual_queue_logic(self): - queue = deque() - # Empty queues are neither done nor ready. - assert not queue_is_ready(queue) - assert not queue_is_done(queue) - - queue.append(to_dt('foo')) - assert queue_is_ready(queue) - assert not queue_is_done(queue) - - - queue.appendleft(to_dt('DONE')) - assert queue_is_ready(queue) - - # Checking done when we have a message after done will trip an assert. - self.assertRaises(AssertionError, queue_is_done, queue) - - queue.pop() - assert queue_is_ready(queue) - assert queue_is_done(queue) - - def test_pop_logic(self): - sources = {} - ids = ['a', 'b', 'c'] - for id in ids: - sources[id] = deque() - - assert not ready(sources) - assert not done(sources) - - # All sources must have a message to be ready/done - sources['a'].append(to_dt("datetime")) - assert not ready(sources) - assert not done(sources) - sources['a'].pop() - - for id in ids: - sources[id].append(to_dt("datetime")) - - assert ready(sources) - assert not done(sources) - - for id in ids: - sources[id].appendleft(to_dt("DONE")) - - # ["DONE", message] will trip an assert in queue_is_done. - assert ready(sources) - self.assertRaises(AssertionError, done, sources) - - for id in ids: - sources[id].pop() - - assert ready(sources) - assert done(sources) - -class DateSortTestCase(TestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def run_date_sort(self, events, expected, source_ids): - """ - Take a list of events, their source_ids, and an expected sorting. - Assert that date_sort's output agrees with expected. 
- """ - sort_gen = date_sort(events, source_ids) - l = list(sort_gen) - assert l == expected - - def test_single_source(self): - source_ids = ['a'] - # 100 events, increasing by a minute at a time. - type = zp.DATASOURCE_TYPE.TRADE - dates = list(date_gen(count = 100)) - dates.append("DONE") - - # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)] - event_args = zip(cycle(source_ids), iter(dates), cycle([type])) - - # Turn event_args into proper events. - events = [mock_data_unframe(*args) for args in event_args] - - # We don't expected Feed to yield the last event. - expected = events[:-1] - - event_gen = (e for e in events) - - self.run_date_sort(event_gen, expected, source_ids) - - def test_multi_source(self): - source_ids = ['a', 'b'] - type = zp.DATASOURCE_TYPE.TRADE - - # Set up source 'a'. Outputs 20 events with 2 minute deltas. - delta_a = timedelta(minutes = 2) - dates_a = list(date_gen(delta = delta_a, count = 20)) - dates_a.append("DONE") - - events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type])) - events_a = [mock_data_unframe(*args) for args in events_a_args] - - # Set up source 'b'. Outputs 10 events with 1 minute deltas. - delta_b = timedelta(minutes = 1) - dates_b = list(date_gen(delta = delta_b, count = 10)) - dates_b.append("DONE") - - events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type])) - events_b = [mock_data_unframe(*args) for args in events_b_args] - - # The expected output is all non-DONE events in both a and b, - # sorted first by dt and then by source_id. - non_dones = events_a[:-1] + events_b[:-1] - expected = sorted(non_dones, compare_by_dt_source_id) - - # Alternating between a and b. - interleaved = alternate(iter(events_a), iter(events_b)) - self.run_date_sort(interleaved, expected, source_ids) - - # All of a, then all of b. 
- - sequential = chain(iter(events_a), iter(events_b)) - self.run_date_sort(sequential, expected, source_ids) - - def test_sorted_sources(self): - - filter = [1,2] - #Set up source a. One hour between events. - args_a = tuple() - kwargs_a = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(hours = 1), - 'filter' : filter - } - #Set up source b. One day between events. - args_b = tuple() - kwargs_b = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(days = 1), - 'filter' : filter - } - #Set up source c. One minute between events. - args_c = tuple() - kwargs_c = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), - 'filter' : filter - } - # Set up source d. This should produce no events because the - # internal sids don't match the filter. - args_d = tuple() - kwargs_d = {'sids' : [3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), - 'filter' : filter - } - - sources = (SpecificEquityTrades,) * 4 - source_args = (args_a, args_b, args_c, args_d) - source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d) - - # Generate our expected source_ids. - zip_args = zip(source_args, source_kwargs) - expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs) - for args, kwargs in zip_args] - - # Pipe our sources into sort. - sort_out = date_sorted_sources(sources, source_args, source_kwargs) - - # Read all the values from sort and assert that they arrive in - # the correct sorting with the expected hash values. - to_list = list(sort_out) - copy = to_list[:] - for e in to_list: - # All events should match one of our expected source_ids. - assert e.source_id in expected_ids - # But none of them should match source_d. 
- assert e.source_id != hash_args(*args_d, **kwargs_d) - - expected = sorted(copy, compare_by_dt_source_id) - assert to_list == expected - -def mock_data_unframe(source_id, dt, type): - event = ndict() - event.source_id = source_id - event.dt = dt - event.type = type - return event - -def to_dt(val): - return ndict({'dt': val}) - -def compare_by_dt_source_id(x,y): - if x.dt < y.dt: - return -1 - elif x.dt > y.dt: - return 1 - - elif x.source_id < y.source_id: - return -1 - elif x.source_id > y.source_id: - return 1 - - else: - return 0 diff --git a/tests/test_finance.py b/tests/test_finance.py index d108e019..e5f26240 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -11,7 +11,6 @@ from collections import defaultdict from nose.tools import timed import zipline.utils.factory as factory -import zipline.protocol as zp from zipline.test_algorithms import TestAlgorithm from zipline.finance.trading import TradingEnvironment @@ -19,10 +18,9 @@ from zipline.core.devsimulator import AddressAllocator from zipline.lines import SimulatedTrading from zipline.finance.performance import PerformanceTracker from zipline.utils.protocol_utils import ndict -from zipline.finance.trading import TransactionSimulator, SIMULATION_STYLE +from zipline.finance.trading import TransactionSimulator from zipline.utils.test_utils import \ drain_zipline, \ - check, \ setup_logger, \ teardown_logger,\ assert_single_position @@ -39,10 +37,8 @@ class FinanceTestCase(TestCase): def setUp(self): self.zipline_test_config = { - 'allocator' : allocator, - 'sid' : 133, - #'devel' : True, - 'results_socket' : allocator.lease(1)[0] + 'sid' : 133, + 'results_socket_uri' : allocator.lease(1)[0] } self.ctx = zmq.Context() @@ -60,7 +56,7 @@ class FinanceTestCase(TestCase): trading_environment ) prev = None - for trade in trade_source.event_list: + for trade in trade_source: if prev: self.assertTrue(trade.dt > prev.dt) prev = trade @@ -123,7 +119,6 @@ class FinanceTestCase(TestCase): 
self.zipline_test_config['order_count'] = 100 self.zipline_test_config['trade_count'] = 200 zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config) - assert_single_position(self, zipline) #@timed(DEFAULT_TIMEOUT) @@ -148,9 +143,6 @@ class FinanceTestCase(TestCase): ) output, transaction_count = drain_zipline(self, zipline) - self.assertTrue(zipline.sim.ready()) - self.assertFalse(zipline.sim.exception) - #check that the algorithm received no events self.assertEqual( 0, @@ -301,12 +293,12 @@ class FinanceTestCase(TestCase): # if present, expect transaction amounts to match orders exactly. complete_fill = params.get('complete_fill') + sid = 1 trading_environment = factory.create_trading_environment() - trade_sim = TransactionSimulator() + trade_sim = TransactionSimulator([sid]) price = [10.1] * trade_count volume = [100] * trade_count start_date = trading_environment.first_open - sid = 1 generated_trades = factory.create_trade_history( sid, @@ -330,7 +322,7 @@ class FinanceTestCase(TestCase): 'dt' : order_date }) - trade_sim.add_open_order(order) + trade_sim.place_order(order) order_date = order_date + order_interval # move after market orders to just after market next @@ -353,14 +345,13 @@ class FinanceTestCase(TestCase): self.assertEqual(order.amount, order_amount * alternator**i) - tracker = PerformanceTracker(trading_environment) + tracker = PerformanceTracker(trading_environment, [sid]) # this approximates the loop inside TradingSimulationClient transactions = [] for trade in generated_trades: if trade_delay: trade.dt = trade.dt + trade_delay - txn = trade_sim.apply_trade_to_open_orders(trade) if txn: transactions.append(txn) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 5f55aaee..76bb6184 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,7 +1,7 @@ from zipline.utils.test_utils import setup_logger, teardown_logger from unittest2 import TestCase, skip -from zipline.core.monitor import Controller +from 
zipline.core.monitor import Monitor class TestMonitor(TestCase): def setUp(self): @@ -14,13 +14,15 @@ class TestMonitor(TestCase): def test_init(self): pub_socket = 'tcp://127.0.0.1:5000' route_socket = 'tcp://127.0.0.1:5001' + exception_socket = 'tcp://127.0.0.1:5002' - con = Controller(pub_socket, route_socket) - con.manage([]) + mon = Monitor(pub_socket, route_socket, exception_socket) + mon.manage([]) def test_init_topology(self): pub_socket = 'tcp://127.0.0.1:5000' route_socket = 'tcp://127.0.0.1:5001' + exception_socket = 'tcp://127.0.0.1:5002' - con = Controller(pub_socket, route_socket, ) - con.manage([ 'a', 'b', 'c', 'd' ]) + mon = Monitor(pub_socket, route_socket, exception_socket) + mon.manage([ 'a', 'b', 'c', 'd' ]) diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 1a77818c..2f9c1df8 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -543,7 +543,10 @@ shares in position" self.trading_environment.capital_base = 1000.0 self.trading_environment.frame_index = ['sid', 'volume', 'dt', \ 'price', 'changed'] - perf_tracker = perf.PerformanceTracker(self.trading_environment) + perf_tracker = perf.PerformanceTracker( + self.trading_environment, + [sid, sid2] + ) for event in trade_history: #create a transaction for all but diff --git a/tests/test_protocol.py b/tests/test_protocol.py index c90b09dc..d1606ed4 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,8 +1,6 @@ """ Test the FRAME/UNFRAME functions in the sequence expected from ziplines. 
""" -import pytz - from unittest2 import TestCase from datetime import datetime, timedelta from collections import defaultdict @@ -10,10 +8,8 @@ from collections import defaultdict from nose.tools import timed import zipline.utils.factory as factory -from zipline.utils import logger import zipline.protocol as zp -from zipline.finance.sources import SpecificEquityTrades DEFAULT_TIMEOUT = 5 # seconds diff --git a/tests/test_sorting.py b/tests/test_sorting.py new file mode 100644 index 00000000..bec97e31 --- /dev/null +++ b/tests/test_sorting.py @@ -0,0 +1,259 @@ +import pytz + +from unittest2 import TestCase +from itertools import cycle, chain, izip, izip_longest +from datetime import datetime, timedelta +from collections import deque + +from zipline import ndict +from zipline.gens.sort import \ + date_sort, \ + ready, \ + done, \ + queue_is_ready,\ + queue_is_done +from zipline.gens.utils import hash_args, alternate, done_message +from zipline.gens.tradegens import date_gen, SpecificEquityTrades +from zipline.gens.composites import date_sorted_sources + +import zipline.protocol as zp + +class HelperTestCase(TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_individual_queue_logic(self): + queue = deque() + # Empty queues are neither done nor ready. + assert not queue_is_ready(queue) + assert not queue_is_done(queue) + + queue.append(to_dt('foo')) + assert queue_is_ready(queue) + assert not queue_is_done(queue) + + + queue.appendleft(to_dt('DONE')) + assert queue_is_ready(queue) + + # Checking done when we have a message after done will trip an assert. 
+ self.assertRaises(AssertionError, queue_is_done, queue) + + queue.pop() + assert queue_is_ready(queue) + assert queue_is_done(queue) + + def test_pop_logic(self): + sources = {} + ids = ['a', 'b', 'c'] + for id in ids: + sources[id] = deque() + + assert not ready(sources) + assert not done(sources) + + # All sources must have a message to be ready/done + sources['a'].append(to_dt("datetime")) + assert not ready(sources) + assert not done(sources) + sources['a'].pop() + + for id in ids: + sources[id].append(to_dt("datetime")) + + assert ready(sources) + assert not done(sources) + + for id in ids: + sources[id].appendleft(to_dt("DONE")) + + # ["DONE", message] will trip an assert in queue_is_done. + assert ready(sources) + self.assertRaises(AssertionError, done, sources) + + for id in ids: + sources[id].pop() + + assert ready(sources) + assert done(sources) + +class DateSortTestCase(TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def run_date_sort(self, event_stream, expected, source_ids): + """ + Take a list of events, their source_ids, and an expected sorting. + Assert that date_sort's output agrees with expected. + """ + sort_out = date_sort(event_stream, source_ids) + for m1, m2 in izip_longest(sort_out, expected): + assert m1 == m2 + + def test_single_source(self): + + # Just using the built-in defaults. See + # zipline/gens/tradegens.py + source = SpecificEquityTrades() + expected = list(source) + source.rewind() + # The raw source doesn't handle done messaging, so we need to + # append a done message for sort to work properly. 
+ with_done = chain(source, [done_message(source.get_hash())]) + self.run_date_sort(with_done, expected, [source.get_hash()]) + + def test_multi_source(self): + + filter = [2,3] + args_a = tuple() + kwargs_a = { + 'count' : 100, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + args_b = tuple() + kwargs_b = { + 'count' : 100, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + all_events = list(chain(source_a, source_b)) + + # The expected output is all events, sorted by dt with + # source_id as a tiebreaker. + expected = sorted(all_events, comp) + source_ids = [source_a.get_hash(), source_b.get_hash()] + + # Generating the events list consumes the sources. Rewind them + # for testing. + source_a.rewind() + source_b.rewind() + + # Append a done message to each source. + with_done_a = chain(source_a, [done_message(source_a.get_hash())]) + with_done_b = chain(source_b, [done_message(source_b.get_hash())]) + + interleaved = alternate(with_done_a, with_done_b) + + # Test sort with alternating messages from source_a and + # source_b. + self.run_date_sort(interleaved, expected, source_ids) + + source_a.rewind() + source_b.rewind() + with_done_a = chain(source_a, [done_message(source_a.get_hash())]) + with_done_b = chain(source_b, [done_message(source_b.get_hash())]) + + sequential = chain(with_done_a, with_done_b) + + # Test sort with all messages from a, followed by all messages + # from b. + + self.run_date_sort(sequential, expected, source_ids) + + + def test_sort_composite(self): + + filter = [1,2] + + #Set up source a. One hour between events. 
+ args_a = tuple() + kwargs_a = { + 'count' : 100, + 'sids' : [1], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(hours = 1), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. One day between events. + args_b = tuple() + kwargs_b = { + 'count' : 50, + 'sids' : [2], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(days = 1), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + #Set up source c. One minute between events. + args_c = tuple() + kwargs_c = { + 'count' : 150, + 'sids' : [1,2], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + source_c = SpecificEquityTrades(*args_c, **kwargs_c) + # Set up source d. This should produce no events because the + # internal sids don't match the filter. + args_d = tuple() + kwargs_d = { + 'count' : 50, + 'sids' : [3], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + source_d = SpecificEquityTrades(*args_d, **kwargs_d) + sources = [source_a, source_b, source_c, source_d] + hashes = [source.get_hash() for source in sources] + + sort_out = date_sorted_sources(*sources) + + # Read all the values from sort and assert that they arrive in + # the correct sorting with the expected hash values. + to_list = list(sort_out) + copy = to_list[:] + + # We should have 300 events (100 from a, 150 from b, 50 from c) + assert len(to_list) == 300 + + for e in to_list: + # All events should match one of our expected source_ids. + assert e.source_id in hashes + # But none of them should match source_d. + assert e.source_id != source_d.get_hash() + + # The events should be sorted by dt, with source_id as tiebreaker. 
+ expected = sorted(copy, comp) + + assert to_list == expected + +def compare_by_dt_source_id(x,y): + if x.dt < y.dt: + return -1 + elif x.dt > y.dt: + return 1 + + elif x.source_id < y.source_id: + return -1 + elif x.source_id > y.source_id: + return 1 + else: + return 0 + +#Alias for ease of use +comp = compare_by_dt_source_id + +def to_dt(msg): + return ndict({'dt': msg}) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 1fe1ce3c..491df1b5 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,102 +1,266 @@ -from datetime import timedelta +import pytz + +from datetime import timedelta, datetime from collections import defaultdict from unittest2 import TestCase +from zipline import ndict + +from zipline.lines import SimulatedTrading + from zipline.utils.test_utils import setup_logger, teardown_logger +from zipline.utils.date_utils import utcnow + +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.transform import StatefulTransform, EventWindow +from zipline.gens.vwap import VWAP +from zipline.gens.mavg import MovingAverage +from zipline.gens.returns import Returns import zipline.utils.factory as factory -from zipline.finance.vwap import DailyVWAP, VWAPTransform -from zipline.finance.returns import ReturnsFromPriorClose -from zipline.finance.movingaverage import MovingAverage -from zipline.lines import SimulatedTrading -from zipline.core.devsimulator import AddressAllocator -allocator = AddressAllocator(1000) +def to_dt(msg): + return ndict({'dt': msg}) -class ZiplineWithTransformsTestCase(TestCase): - leased_sockets = defaultdict(list) +class NoopEventWindow(EventWindow): + """ + A no-op EventWindow subclass for testing the base EventWindow logic. + Keeps lists of all added and dropped events. 
+ """ + def __init__(self, market_aware, days, delta): + EventWindow.__init__(self, market_aware, days, delta) + + self.added = [] + self.removed = [] + + def handle_add(self, event): + self.added.append(event) + + def handle_remove(self, event): + self.removed.append(event) + +class EventWindowTestCase(TestCase): def setUp(self): - # skip ahead 100 spots - allocator.lease(100) - self.trading_environment = factory.create_trading_environment() - self.zipline_test_config = { - 'allocator' : allocator, - 'sid' : 133, - 'devel' : True - } - setup_logger(self, '/var/log/qexec/qexed.log') + setup_logger(self) + + # Constants calling before open, during the day, and after + # close on a valid trading day. + self.pre_open = datetime(2012, 8, 7, 13, tzinfo = pytz.utc) + self.mid_day = datetime(2012, 8, 7, 15, tzinfo = pytz.utc) + self.post_close = datetime(2012, 8, 7, 22, tzinfo = pytz.utc) - def tearDown(self): - teardown_logger(self) + # Constants calling before open, during the day, and after + # close on a saturday. + self.pre_open_saturday = datetime(2012, 8, 11, 13, tzinfo = pytz.utc) + self.mid_day_saturday = datetime(2012, 8, 11, 15, tzinfo = pytz.utc) + self.post_close_saturday = datetime(2012, 8, 11, 22, tzinfo = pytz.utc) - def test_vwap_tnfm(self): - zipline = SimulatedTrading.create_test_zipline( - **self.zipline_test_config + # Constants calling before open, during the day, and after + # close on a holiday. + self.pre_open_holiday = datetime(2012, 12, 25, 13, tzinfo = pytz.utc) + self.mid_day_holiday = datetime(2012, 12, 25, tzinfo = pytz.utc) + self.post_close_holiday = datetime(2012, 12, 25, 22, tzinfo = pytz.utc) + + def test_event_window_with_timedelta(self): + + # Keep all events within a 5 minute window. 
+ window = NoopEventWindow( + market_aware = False, + delta = timedelta(minutes = 5), + days = None ) - vwap = VWAPTransform("vwap_10", daycount=10) - zipline.add_transform(vwap) + now = utcnow() - zipline.simulate(blocking=True) + # 15 dates, increasing in 1 minute increments. + dates = [now + i * timedelta(minutes = 1) + for i in xrange(15)] - self.assertTrue(zipline.sim.ready()) - self.assertFalse(zipline.sim.exception) + # Turn the dates into the format required by EventWindow. + dt_messages = [to_dt(date) for date in dates] + + # Run all messages through the window and assert that we're adding + # and removing messages appropriately. We start the enumeration at 1 + # for convenience. + for num, message in enumerate(dt_messages, 1): + window.update(message) + + # Assert that we've added the correct number of events. + assert len(window.added) == num + + # Assert that we removed only events that fall outside (or + # on the boundary of) the delta. + for dropped in window.removed: + assert message.dt - dropped.dt >= timedelta(minutes = 5) + + def test_market_aware_window(self): + window = NoopEventWindow( + market_aware = True, + delta = None, + days = 1 + ) + dates = ([self.pre_open]*3) + dates += ([self.mid_day]*3) + dates += ([self.post_close]*3) + dates += [self.pre_open + timedelta(days = 1, seconds = 1)] + events = [to_dt(date) for date in dates] + + # Run the events. + for event in events: + window.update(event) + + # We should have removed the pre_open events on the first day. + # The rest should be intact. 
+ + assert window.added == events + assert window.removed == events[0:3] + assert list(window.ticks) == events[3:] + + def test_market_aware_window_weekend(self): + window = NoopEventWindow( + market_aware = True, + delta = None, + days = 2 + ) + dates = [self.pre_open_saturday - timedelta(days = 1, seconds=1)] + dates += [self.mid_day_saturday - timedelta(days = 1, seconds=1)] + dates += [self.post_close_saturday - timedelta(days = 1, seconds=1)] + dates += [self.mid_day_saturday + timedelta(days = 1)] + + events = [to_dt(date) for date in dates] + + # Run the events. + for event in events: + window.update(event) + + # We shouldn't remove any events. + assert window.added == events + assert window.removed == [] + assert list(window.ticks) == events + + extra = to_dt(self.mid_day_saturday + timedelta(days = 2)) + window.update(extra) + + # We should remove only the first event. + assert window.removed == [events[0]] + assert list(window.ticks) == events[1:] + [extra] + + def tearDown(self): + setup_logger(self) class FinanceTransformsTestCase(TestCase): def setUp(self): self.trading_environment = factory.create_trading_environment() - setup_logger(self, '/var/log/qexec/qexec.log') + setup_logger(self) + + trade_history = factory.create_trade_history( + 133, + [10.0, 10.0, 11.0, 11.0], + [100, 100, 100, 300], + timedelta(days=1), + self.trading_environment + ) + self.source = SpecificEquityTrades(event_list=trade_history) def tearDown(self): self.log_handler.pop_application() def test_vwap(self): - trade_history = factory.create_trade_history( - 133, - [10.0, 10.0, 10.0, 11.0], - [100, 100, 100, 300], - timedelta(days=1), - self.trading_environment + vwap = StatefulTransform( + VWAP, + market_aware = False, + delta = timedelta(days = 2) ) + transformed = list(vwap.transform(self.source)) - vwap = DailyVWAP(days=2) - for trade in trade_history: - vwap.update(trade) - - self.assertEqual(vwap.vwap, 10.75) + # Output values + tnfm_vals = [message.tnfm_value for message 
in transformed] + # "Hand calculated" values. + expected = [ + (10.0 * 100) / 100.0, + ((10.0 * 100) + (10.0 * 100)) / (200.0), + # We should drop the first event here. + ((10.0 * 100) + (11.0 * 100)) / (200.0), + # We should drop the second event here. + ((11.0 * 100) + (11.0 * 300)) / (400.0) + ] + # Output should match the expected. + assert tnfm_vals == expected def test_returns(self): + # Daily returns. + returns = StatefulTransform(Returns, 1) + + transformed = list(returns.transform(self.source)) + tnfm_vals = [message.tnfm_value for message in transformed] + + # No returns for the first event because we don't have a + # previous close. + expected = [0.0, 0.0, 0.1, 0.0] + + assert tnfm_vals == expected + + # Two-day returns. An extra kink here is that the + # factory will automatically skip a weekend for the + # last event. Results shouldn't notice this blip. + trade_history = factory.create_trade_history( 133, - [10.0, 10.0, 10.0, 11.0], - [100, 100, 100, 300], + [10.0, 15.0, 13.0, 12.0, 13.0], + [100, 100, 100, 300, 100], timedelta(days=1), self.trading_environment ) + self.source = SpecificEquityTrades(event_list=trade_history) - returns = ReturnsFromPriorClose() - for trade in trade_history: - returns.update(trade) + returns = StatefulTransform(Returns, 2) + transformed = list(returns.transform(self.source)) + tnfm_vals = [message.tnfm_value for message in transformed] - self.assertEqual(returns.returns, .1) + expected = [ + 0.0, + 0.0, + (13.0 - 10.0) / 10.0, + (12.0 - 15.0) / 15.0, + (13.0 - 13.0) / 13.0 + ] + assert tnfm_vals == expected def test_moving_average(self): - trade_history = factory.create_trade_history( - 133, - [10.0, 10.0, 10.0, 11.0], - [100, 100, 100, 300], - timedelta(days=1), - self.trading_environment + + mavg = StatefulTransform( + MovingAverage, + market_aware = False, + fields = ['price', 'volume'], + delta = timedelta(days = 2), ) + transformed = list(mavg.transform(self.source)) + # Output values. 
+ tnfm_prices = [message.tnfm_value.price for message in transformed] + tnfm_volumes = [message.tnfm_value.volume for message in transformed] - ma = MovingAverage(days=2) - for trade in trade_history: - ma.update(trade) + # "Hand-calculated" values + expected_prices = [ + ((10.0) / 1.0), + ((10.0 + 10.0) / 2.0), + # First event should get dropped here. + ((10.0 + 11.0) / 2.0), + # Second event should get dropped here. + ((11.0 + 11.0) / 2.0) + ] + expected_volumes = [ + ((100.0) / 1.0), + ((100.0 + 100.0) / 2.0), + # First event should get dropped here. + ((100.0 + 100.0) / 2.0), + # Second event should get dropped here. + ((100.0 + 300.0) / 2.0) + ] - - self.assertEqual(ma.average, 10.5) + assert tnfm_prices == expected_prices + assert tnfm_volumes == expected_volumes diff --git a/zipline/__init__.py b/zipline/__init__.py index 23bcca40..31272fcb 100644 --- a/zipline/__init__.py +++ b/zipline/__init__.py @@ -6,15 +6,9 @@ Zipline # it is a place to expose the public interfaces. import protocol # namespace -from core.monitor import Controller -from lines import SimulatedTrading -from core.host import ComponentHost from utils.protocol_utils import ndict __all__ = [ - SimulatedTrading, - Controller, - ComponentHost, protocol, ndict ] diff --git a/zipline/components/tradesimulation.py b/zipline/components/tradesimulation.py index aa8ca937..eabddaec 100644 --- a/zipline/components/tradesimulation.py +++ b/zipline/components/tradesimulation.py @@ -12,7 +12,7 @@ from zipline.utils.protocol_utils import ndict from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe -from logbook import Logger, NestedSetup, Processor, queues +from logbook import Logger, NestedSetup, Processor log = logbook.Logger('TradeSimulation') @@ -20,7 +20,7 @@ log = logbook.Logger('TradeSimulation') class TradeSimulationClient(Component): - def init(self, trading_environment, sim_style, results_socket): + def init(self, trading_environment, sim_style, results_socket, algorithm): 
self.received_count = 0 self.prev_dt = None self.event_queue = None @@ -29,13 +29,20 @@ class TradeSimulationClient(Component): self.trading_environment = trading_environment self.current_dt = trading_environment.period_start self.last_iteration_dur = datetime.timedelta(seconds=0) - self.algorithm = None + self.algorithm = algorithm + self.algorithm.set_order(self.order) self.max_wait = datetime.timedelta(seconds=60) self.last_msg_dt = datetime.datetime.utcnow() - self.txn_sim = TransactionSimulator(sim_style) + self.txn_sim = TransactionSimulator( + open_orders={}, + style=sim_style + ) self.event_data = ndict() - self.perf = perf.PerformanceTracker(self.trading_environment) + self.perf = perf.PerformanceTracker( + self.trading_environment, + self.algorithm.get_sid_filter() + ) self.zmq_out = None self.results_socket = results_socket self.algo_initialized = False @@ -44,19 +51,6 @@ class TradeSimulationClient(Component): def get_id(self): return str(zp.FINANCE_COMPONENT.TRADING_CLIENT) - def set_algorithm(self, algorithm): - """ - :param algorithm: must implement the algorithm protocol. See - :py:mod:`zipline.test.algorithm` - """ - self.algorithm = algorithm - # register the client's order method with the algorithm - self.algorithm.set_order(self.order) - # we need to provide the performance tracker with the - # sids referenced in the algorithm, so portfolio can - # initialize with all possible sids. 
- self.perf.set_sids(self.algorithm.get_sid_filter()) - def open(self): self.result_feed = self.connect_result() if self.results_socket: @@ -185,16 +179,6 @@ class TradeSimulationClient(Component): # LOG_EXTRA_FIELDS in zipline/protocol.py self.do_op(self.algorithm.handle_data, data) - def exception_callback(self, exc_type, exc_value, exc_traceback): - if self.results_socket: - log.info("Sending exception frame") - msg = zp.EXCEPTION_FRAME( - exc_traceback, - exc_type.__name__, - exc_value.message - ) - self.out_socket.send(msg) - def do_op(self, callable_op, *args, **kwargs): """ Wrap a callable operation with the zmq logbook handler if it exits.""" diff --git a/zipline/core/__init__.py b/zipline/core/__init__.py index d487dd05..a7d6b1f8 100644 --- a/zipline/core/__init__.py +++ b/zipline/core/__init__.py @@ -1,9 +1,9 @@ from host import ComponentHost from component import Component -from monitor import Controller +from monitor import Monitor __all__ = [ Component, - Controller, + Monitor, ComponentHost ] diff --git a/zipline/core/component.py b/zipline/core/component.py index fc859307..c9ce56e6 100644 --- a/zipline/core/component.py +++ b/zipline/core/component.py @@ -8,319 +8,225 @@ import uuid import time import socket import logbook -import traceback import humanhash +import multiprocessing from setproctitle import setproctitle +from collections import namedtuple + # pyzmq import zmq from zipline.core.monitor import PARAMETERS -from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_STATE, \ - COMPONENT_FAILURE, CONTROL_FRAME, CONTROL_UNFRAME +from zipline.protocol import ( + CONTROL_PROTOCOL, + COMPONENT_STATE, + CONTROL_FRAME, + CONTROL_UNFRAME, + EXCEPTION_FRAME +) + log = logbook.Logger('Component') -from zipline.exceptions import ComponentNoInit - class KillSignal(Exception): def __init__(self): pass +class ShutdownSignal(Exception): + def __init__(self): + pass + +ComponentSocketArgs = namedtuple('ComponentSocketArgs',['uri','style','bind']) + class 
Component(object): - """ - Base class for components. Defines the the base messaging - interface for components. - - :param addresses: a dict of name_string -> zmq port address strings. - Must have the following entries - - :param data_address: socket address used for data sources to stream - their records. Will be used in PUSH/PULL sockets - between data sources and a Feed. Bind will always - be on the PULL side (we always have N producers and - 1 consumer) - - :param feed_address: socket address used to publish consolidated feed - from serialization of data sources - will be used in PUB/SUB sockets between Feed and - Transforms. Bind is always on the PUB side. - - :param merge_address: socket address used to publish transformed - values. will be used in PUSH/PULL from many - transforms to one Merge Bind will always be on - the PULL side (we always have N producers and - 1 consumer) - - :param results_address: socket address used to publish merged data - source feed and transforms to clients will be - used in PUB/SUB from one Merge to one or many - clients. Bind is always on the PUB side. - - bind/connect methods will return the correct socket type for each - address. 
- - """ - # ------------ # Construction # ------------ - abstract = True - #__metaclass__ = WorkflowMeta + def __init__(self, + generator, + monitor, + socket_uri, + frame, + unframe, + component_id + ): - def __init__(self, *args, **kwargs): - self.zmq = None - self.context = None - self.addresses = None - self.waiting = None + # ----------------- + # Generator + # ----------------- + self.generator = generator + self.frame = frame + self.component_id = component_id + + # lock for waiting on monitor "GO" + self.waiting = None + + # ----------------- + # ZMQ properties + # ----------------- + self.in_socket_args = ComponentSocketArgs( + uri = socket_uri, + style = zmq.PULL, + bind = False + ) + self.out_socket_args = ComponentSocketArgs( + uri = socket_uri, + style = zmq.PUSH, + bind = True + ) + self.zmq = None + self.context = None + self.out_socket = None + self.in_socket = None + self.monitor = monitor + self.unframe = unframe + self.prefix = "" - self.out_socket = None - self.killed = False - self.controller = None - # timeout on heartbeat is very short to avoid burning - # cycles on heartbeating. unit is milliconds - self.heartbeat_timeout = 0 # TODO: state_flag is deprecated, remove - # TODO: error_state is deprecated, remove - self.state_flag = COMPONENT_STATE.OK - self.error_state = COMPONENT_FAILURE.NOFAILURE - self.on_done = None + self.state_flag = COMPONENT_STATE.OK - self._exception = None - self.fail_time = None - self.start_tic = None - self.stop_tic = None - self.note = None - self.confirmed = False - self.devel = False - self.socks = None - self.last_ping = None + # track time of last ping we received from monitor + self.last_ping = time.time() # Humanhashes make this way easier to debug because they stick # in your mind unlike a 32 byte string of random hex. 
- self.guid = uuid.uuid4() - self.huid = humanhash.humanize(self.guid.hex) + self.guid = uuid.uuid4() + self.huid = humanhash.humanize(self.guid.hex) - # This is where component specific constructors should be - # defined. Arguments passed to init are threaded through. - self.init(*args, **kwargs) + # first, start the generator in its own process. Once + # Monitor says "go", Events from the generator will be + # FRAME'd and PUSH'd to self.socket_uri. + monitor.add_to_topology(self.component_id) - def init(self): - """ - Subclasses should override this to extend the setup for the - class. Shouldn't have side effects. - """ - raise ComponentNoInit(self.__class__) + self.proc = multiprocessing.Process( + target=self.loop_send + ) + self.proc.start() + + # Placeholder for receive generator, which will be + # created in __iter__ + self.recv_gen = None # ------------ # Core Methods # ------------ - def open(self): - """ - Open the connections needed to start doing work. - """ - raise NotImplementedError - - def ready(self): - """ - Return ``True`` if and only if the component has finished - execution. - """ - return self.state_flag in [COMPONENT_STATE.DONE, \ - COMPONENT_STATE.EXCEPTION] - - def successful(self): - """ - Return ``True`` if and only if the component has finished - execution successfully, that is, without raising an error. - """ - return self.state_flag == COMPONENT_STATE.DONE and not \ - self.exception - - @property - def exception(self): - """ - Holds the exception that the component failed on, or ``None`` if - the component has not failed. - """ - return self._exception - - def do_work(self): - raise NotImplementedError - - def init_zmq(self): - self.zmq = zmq - self.context = self.zmq.Context() - self.zmq_poller = self.zmq.Poller - # The the process title so you can watch it in top - setproctitle(self.__class__.__name__) - return - - def _run(self): + def loop_send(self): """ The main component loop. 
This is wrapped inside a exception reporting context inside of run. The core logic of the all components is run here. """ - log.info("Start %r" % self) - log.info("Pid %s" % os.getpid()) - log.info("Group %s" % os.getpgrp()) - - self.start_tic = time.time() - - self.done = False # TODO: use state flag - self.sockets = [] - - self.init_zmq() - - self.setup_poller() - - self.setup_control() - self.open() - - self.signal_ready() - self.lock_ready() - self.wait_ready() - # ----------------------- - # YOU SHALL NOT PASS!!!!! - # ----------------------- - # ... until the controller signals GO - - self.loop() - - self.stop_tic = time.time() - - def run(self, catch_exceptions=True): - """ - Run the component. - """ try: - self._run() - except Exception as exc: - if not isinstance(exc, KillSignal): - self.signal_exception(exc) - else: - # if we get a kill signal, forcibly close all the - # sockets. - # exc_info = sys.exc_info() - # self.relay_exception(exc_info[0], exc_info[1], exc_info[2]) - self.teardown_sockets() + # The process title so you can watch it in top, ps. + self.prefix = "FORK-" + setproctitle(self.get_id) + log.info("Start %r" % self) + log.info("Pid %s" % os.getpid()) + log.info("Group %s" % os.getpgrp()) + + self.open() + + self.signal_ready() + self.lock_ready() + + msg = None + for event in self.generator: + + if hasattr(event, 'dt') and event.dt == 'DONE': + continue + + self.wait_ready() + + self.heartbeat() + msg = self.frame(event) + self.out_socket.send(msg) + + self.signal_done() + + # keep heartbeating until we receive the shutdown + # message from the Monitor (raises a + # ShutdownSignal), or we don't hear from the Monitor + # for MAX_COMPONENT_WAIT. + while True: + self.heartbeat(timeout=1000) + + except Exception as exc: + self.handle_exception(exc) finally: - self.shutdown() log.info("Exiting %r" % self) - def working(self): - """ - Controls when the work loop will start and end - If we encounter an exception or signal done exit. 
+ def create_recv_gen(self): + try: + # return the generator + return self.loop_recv() + except Exception as exc: + self.handle_exception(exc) + finally: + log.info("Created Recv Gen for %r" % self) - Overload for higher order behavior. - """ - return (not self.done) + def loop_recv(self): + try: + self.open(send=False) + self.signal_ready() + self.lock_ready() - def loop(self, lockstep=True): - """ - Loop to do work while we still have work to do. - """ - while self.working(): + # we block on ready here until monitor sends the GO + # self.wait_ready() + for event in self.gen_from_poller(self.poll, self.in_socket, self.unframe): + yield event + + self.signal_done() + except Exception as exc: + self.handle_exception(exc) + finally: + log.info("Exiting %r" % self) + + def gen_from_poller(self, poller, in_socket, unframe): + + while True: + # Since we will yield None to avoid blocking, we need + # to have a small delay to give the poller a chance + # to receive a message from upstream. + socks = dict(poller.poll(100)) self.heartbeat() - self.do_work() + if socks.get(in_socket) == zmq.POLLIN: + message = in_socket.recv() + if message == str(CONTROL_PROTOCOL.DONE): + break + else: + event = unframe(message) + yield event + else: + yield - def runtime(self): - if self.ready() and self.start_tic and self.stop_tic: - return self.stop_tic - self.start_tic + def handle_exception(self, exc, re_raise=False): + if isinstance(exc, KillSignal): + # if we get a kill signal, forcibly close all the + # sockets. + self.teardown_sockets() + elif isinstance(exc, ShutdownSignal): + # signal from monitor of an orderly shutdown, + # do nothing. 
+ pass + else: + self.signal_exception(exc) - def heartbeat(self, timeout=0): - # wait for synchronization reply from the host - self.socks = dict(self.poll.poll(timeout)) + def __iter__(self): + return self - # ---------------- - # Control Dispatch - # ---------------- - assert self.control_in, 'Component does not have a control_in socket' - - # If we're in devel mode drop out because the controller - # isn't guaranteed to be around anymore - if self.devel: - log.warn("Skipping heartbeat because of devel flag") - return - - if self.socks.get(self.control_in) == zmq.POLLIN: - msg = self.control_in.recv() - event, payload = CONTROL_UNFRAME(msg) - - # =========== - # Heartbeat - # =========== - - # The controller will send out a single number packed in - # a CONTROL_FRAME with ``heartbeat`` event every - # (n)-seconds. The component then has n seconds to - # respond to it. If not then it will be considered as - # malfunctioning or maybe CPU bound. - - if event == CONTROL_PROTOCOL.HEARTBEAT: - # Heart outgoing - heartbeat_frame = CONTROL_FRAME( - CONTROL_PROTOCOL.OK, - payload - ) - - self.last_ping = float(payload) - # Echo back the heartbeat identifier to tell the - # controller that this component is still alive and - # doing work - self.control_out.send(heartbeat_frame) - - - # ========= - # Soft Kill - # ========= - - # Try and clean up properly and send out any reports or - # data that are done during a clean shutdown. Inform the - # controller that we're done. - elif event == CONTROL_PROTOCOL.SHUTDOWN: - self.signal_done() - - # ========= - # Hard Kill - # ========= - - # Just exit. - elif event == CONTROL_PROTOCOL.KILL: - self.kill() - - # In case we didn't receive a ping, send a pre-emptive - # pong to the monitor. 
- elif hasattr(self, 'control_out') and \ - self.last_ping and \ - time.time() - self.last_ping > 1: - # send a ping ahead of schedule - pre_pong = time.time() - heartbeat_frame = CONTROL_FRAME( - CONTROL_PROTOCOL.OK, - str(pre_pong) - ) - - # Echo back the heartbeat identifier to tell the - # controller that this component is still alive and - # doing work - self.control_out.send(heartbeat_frame, self.zmq.NOBLOCK) - self.last_ping = pre_pong - elif self.last_ping and \ - time.time() - self.last_ping > PARAMETERS.MAX_COMPONENT_WAIT: - # monitor is gone without sending the shutdown - # signal, do a hard exit. - self.kill() + def next(self): + if not self.recv_gen: + self.recv_gen = self.create_recv_gen() + return self.recv_gen.next() # ---------------------------- # Cleanup & Modes of Failure @@ -339,12 +245,8 @@ class Component(object): def shutdown(self): """ Clean shutdown. - - Tear down after normal operation. """ - if self.on_done: - log.warn("{id} calling done.".format(id=self.get_id)) - self.on_done() + raise ShutdownSignal() def kill(self): """ @@ -353,9 +255,64 @@ class Component(object): Tear down ( fast ) as a mode of failure in the simulation or on service halt. """ - # sys.exit(1) raise KillSignal() + def signal_exception(self, exc=None, scope=None): + """ + All exceptions inside any component should boil back to + this handler. + + Will inform the system that the component has failed and how it + has failed. + """ + self.state_flag = COMPONENT_STATE.EXCEPTION + exc_type, exc_value, exc_traceback = sys.exc_info() + + # if a downstream component fails, this component may try + # sending when there are zero connections to the socket, + # which will raise ZMQError(EAGAIN). So, it doesn't make + # sense to relay this exception to Monitor and the rest + # of the zipline. 
+ if isinstance(exc, zmq.ZMQError) and exc.errno == zmq.EAGAIN: + log.warn("{id} raised a ZMQError(EAGAIN) not relaying"\ + .format(id=self.get_id)) + return + + # sys.stdout.write(trace) + log.exception("Unexpected error in run for {id}.".format(id=self.get_id)) + + try: + log.info('{id} sending exception to monitor'\ + .format(id=self.get_id)) + msg = EXCEPTION_FRAME( + exc_traceback, + exc_type.__name__, + exc_value.message + ) + + exception_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.EXCEPTION, + msg + ) + self.control_out.send(exception_frame, self.zmq.NOBLOCK) + # The monitor should relay the exception back + # to all zipline components. Wait here until the + # notice arrives, and we can assume other zipline + # components have broken out of their message + # loops. + for i in xrange(PARAMETERS.MAX_COMPONENT_WAIT): + self.heartbeat(timeout=1000) + log.warn("{id} never heard back from monitor."\ + .format(id=self.get_id)) + + except KillSignal: + log.info("{id} received confirmation from monitor"\ + .format(id=self.get_id)) + except: + log.exception("Exception waiting for monitor reply") + + + # ---------------------- # Internal Maintenance # ---------------------- @@ -399,7 +356,7 @@ class Component(object): # Go # ==== - # A distributed lock from the controller to ensure + # A distributed lock from the monitor to ensure # synchronized start. 
if event == CONTROL_PROTOCOL.HEARTBEAT: @@ -411,10 +368,10 @@ class Component(object): log.info('Prestart Heartbeat ' + self.get_id) elif event == CONTROL_PROTOCOL.GO: - # Side effectful call from the controller to unlock + # Side effectful call from the monitor to unlock # and begin doing work only when the entire topology # of the system beings to come online - log.info('Unlocking ' + self.__class__.__name__) + log.info('Unlocking ' + self.get_id) self.unlock_ready() # ========= @@ -423,9 +380,9 @@ class Component(object): # Try and clean up properly and send out any reports or # data that are done during a clean shutdown. Inform the - # controller that we're done. + # monitor that we're done. elif event == CONTROL_PROTOCOL.SHUTDOWN: - self.signal_done() + self.shutdown() break # ========= @@ -439,101 +396,93 @@ class Component(object): elif time.time() - start_wait > PARAMETERS.MAX_COMPONENT_WAIT: log.info('No go signal from monitor, %s exiting' \ - % self.__class__.__name__) + % self.get_id) self.kill() break + def heartbeat(self, timeout=0): + # wait for synchronization reply from the host + socks = dict(self.poll.poll(timeout)) + + # ---------------- + # Control Dispatch + # ---------------- + assert self.control_in, 'Component does not have a control_in socket' + + if socks.get(self.control_in) == zmq.POLLIN: + msg = self.control_in.recv() + event, payload = CONTROL_UNFRAME(msg) + + # =========== + # Heartbeat + # =========== + + # The monitor will send out a single number packed in + # a CONTROL_FRAME with ``heartbeat`` event every + # (n)-seconds. The component then has n seconds to + # respond to it. If not then it will be considered as + # malfunctioning or maybe CPU bound. 
+ + if event == CONTROL_PROTOCOL.HEARTBEAT: + # Heart outgoing + heartbeat_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.OK, + payload + ) + + self.last_ping = float(payload) + # Echo back the heartbeat identifier to tell the + # monitor that this component is still alive and + # doing work + self.control_out.send(heartbeat_frame) + + + # ========= + # Soft Kill + # ========= + + # Try and clean up properly and send out any reports or + # data that are done during a clean shutdown. Inform the + # monitor that we're done. + elif event == CONTROL_PROTOCOL.SHUTDOWN: + self.shutdown() + + # ========= + # Hard Kill + # ========= + + # Just exit. + elif event == CONTROL_PROTOCOL.KILL: + self.kill() + + # In case we didn't receive a ping, send a pre-emptive + # pong to the monitor. + elif time.time() - self.last_ping > 2: + # send a ping ahead of schedule + pre_pong = time.time() + heartbeat_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.OK, + str(pre_pong) + ) + + # Echo back the heartbeat identifier to tell the + # monitor that this component is still alive and + # doing work + self.control_out.send(heartbeat_frame, self.zmq.NOBLOCK) + self.last_ping = pre_pong + elif time.time() - self.last_ping > PARAMETERS.MAX_COMPONENT_WAIT: + # monitor is gone without sending the shutdown + # signal, do a hard exit. + self.kill() + def signal_ready(self): - log.info(self.__class__.__name__ + ' is ready') - - if hasattr(self, 'control_out'): - frame = CONTROL_FRAME( - CONTROL_PROTOCOL.READY, - '' - ) - self.control_out.send(frame) - - def signal_cancel(self): - self.done = True - - # TODO: no hasattr hacks - #if not self.controller: - if hasattr(self, 'control_out'): - frame = CONTROL_FRAME( - CONTROL_PROTOCOL.SHUTDOWN, - None - ) - self.control_out.send(frame) - - # then proceeds to do shutdown(), and teardown_sockets() - # to complete the process - - def signal_exception(self, exc=None, scope=None): - """ - All exceptions inside any component should boil back to - this handler. 
- - Will inform the system that the component has failed and how it - has failed. - """ - - if scope == 'algo': - self.error_state = COMPONENT_FAILURE.ALGOEXCEPT - else: - self.error_state = COMPONENT_FAILURE.HOSTEXCEPT - - self.state_flag = COMPONENT_STATE.EXCEPTION - # mark the time of failure so we can track the failure - # progogation through the system. - - self.stop_tic = time.time() - - self._exception = exc - exc_type, exc_value, exc_traceback = sys.exc_info() - trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - - # if a downstream component fails, this component may try - # sending when there are zero connections to the socket, - # which will raise ZMQError(EAGAIN). So, it doesn't make - # sense to relay this exception to Monitor and the rest - # of the zipline. - if isinstance(exc, zmq.ZMQError) and exc.errno == zmq.EAGAIN: - log.warn("{id} raised a ZMQError(EAGAIN) not relaying"\ - .format(id=self.get_id)) - return - - # sys.stdout.write(trace) - log.exception("Unexpected error in run for {id}.".format(id=self.get_id)) - - self.relay_exception(exc_type, exc_value, exc_traceback) - - if hasattr(self, 'control_out') and self.control_out: - try: - log.info('{id} sending exception to controller'.format(id=self.get_id)) - exception_frame = CONTROL_FRAME( - CONTROL_PROTOCOL.EXCEPTION, - trace - ) - self.control_out.send(exception_frame, self.zmq.NOBLOCK) - # The controller should relay the exception back - # to all zipline components. Wait here until the - # notice arrives, and we can assume other zipline - # components have broken out of their message - # loops. 
- for i in xrange(PARAMETERS.MAX_COMPONENT_WAIT): - self.heartbeat(timeout=1000) - log.warn("{id} never heard back from monitor."\ - .format(id=self.get_id)) - except: - log.exception("Exception waiting for controller reply") - - def relay_exception(self, exc_type, exc_value, exc_traceback): - if hasattr(self, 'exception_callback') and self.exception_callback: - log.info('{id} making exception callback'.format(id=self.get_id)) - self.exception_callback(exc_type, exc_value, exc_traceback) - - + log.info(self.get_id + ' is ready') + frame = CONTROL_FRAME( + CONTROL_PROTOCOL.READY, + '' + ) + self.control_out.send(frame) def signal_done(self): """ @@ -544,69 +493,72 @@ class Component(object): # notify internal work loop that we're done self.done = True # TODO: use state flag - if hasattr(self, 'out_socket') and self.out_socket: + if self.out_socket: msg = zmq.Message(str(CONTROL_PROTOCOL.DONE)) self.out_socket.send(msg) - if hasattr(self, 'control_out'): - # notify controller we're done - done_frame = CONTROL_FRAME( - CONTROL_PROTOCOL.DONE, - '' - ) - - self.control_out.send(done_frame) - log.info("[%s] sent control done" % self.get_id) - - # there is a narrow race condition where we finish just - # after the Monitor accepts our prior heartbeat, but just - # before the next one is sent. So, we hang around for one - # last heartbeat, and wait an unusually long time. - self.heartbeat(timeout=5000) - - + # notify monitor we're done + done_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.DONE, + '' + ) + self.control_out.send(done_frame) + log.info("[%s] sent control done" % self.get_id) # ----------- # Messaging # ----------- - def setup_poller(self): + def open(self, send=True): """ - Setup the poller used for multiplexing the incoming data - handling sockets. + Open the connections needed to start doing work. + Perform any setup that must be done within process. 
""" - self.poll = self.zmq_poller() + self.sockets = [] + self.zmq = zmq + self.context = self.zmq.Context() + self.poll = self.zmq.Poller() - def bind_data(self): - return self.bind_pull_socket(self.addresses['data_address']) + self.setup_control() - def connect_data(self): - return self.connect_push_socket(self.addresses['data_address']) + if send: + self.out_socket = self.open_socket(self.out_socket_args) + self.sockets.extend([self.out_socket]) + else: + self.in_socket = self.open_socket(self.in_socket_args) + self.sockets.extend([self.in_socket]) - def bind_feed(self): - return self.bind_pub_socket(self.addresses['feed_address']) + def open_socket(self, sock_args): + if sock_args.bind: + return self.bind_socket(sock_args) + else: + return self.connect_socket(sock_args) - def connect_feed(self): - return self.connect_sub_socket(self.addresses['feed_address']) + def bind_socket(self, sock_args): + if sock_args.style == zmq.PULL: + return self.bind_pull_socket(sock_args.uri) + if sock_args.style == zmq.PUSH: + return self.bind_push_socket(sock_args.uri) + if sock_args.style == zmq.PUB: + return self.bind_pub_socket(sock_args.uri) - def bind_merge(self): - return self.bind_pull_socket(self.addresses['merge_address']) + raise Exception("Invalid socket arguments") - def connect_merge(self): - return self.connect_push_socket(self.addresses['merge_address']) + def connect_socket(self, sock_args): + if sock_args.style == zmq.PULL: + return self.connect_pull_socket(sock_args.uri) + if sock_args.style == zmq.PUSH: + return self.connect_push_socket(sock_args.uri) + if sock_args.style == zmq.SUB: + return self.connect_sub_socket(sock_args.uri) - def bind_result(self): - return self.bind_push_socket(self.addresses['results_address']) - - def connect_result(self): - return self.connect_pull_socket(self.addresses['results_address']) + raise Exception("Invalid socket arguments") def bind_push_socket(self, addr): push_socket = self.context.socket(self.zmq.PUSH) 
push_socket.bind(addr) - self.out_socket = push_socket self.sockets.append(push_socket) return push_socket @@ -624,7 +576,6 @@ class Component(object): pull_socket = self.context.socket(self.zmq.PULL) pull_socket.bind(addr) self.poll.register(pull_socket, self.zmq.POLLIN) - self.sockets.append(pull_socket) return pull_socket @@ -632,29 +583,10 @@ class Component(object): def connect_push_socket(self, addr): push_socket = self.context.socket(self.zmq.PUSH) push_socket.connect(addr) - #push_socket.setsockopt(self.zmq.LINGER,0) self.sockets.append(push_socket) - self.out_socket = push_socket return push_socket - def bind_pub_socket(self, addr): - pub_socket = self.context.socket(self.zmq.PUB) - pub_socket.bind(addr) - #pub_socket.setsockopt(self.zmq.LINGER, 0) - self.out_socket = pub_socket - - return pub_socket - - def connect_sub_socket(self, addr): - sub_socket = self.context.socket(self.zmq.SUB) - sub_socket.connect(addr) - sub_socket.setsockopt(self.zmq.SUBSCRIBE,'') - self.sockets.append(sub_socket) - - self.poll.register(sub_socket, self.zmq.POLLIN) - - return sub_socket def setup_control(self): """ @@ -662,92 +594,32 @@ class Component(object): of the simulation and to forcefully tear down the simulation in case of a failure. """ - - # Allow for the possibility of not having a controller, - # possibly the zipline devsimulator may not want this. 
- if not self.controller: - return - - self.control_out = self.controller.message_sender( + self.control_out = self.monitor.message_sender( identity = self.get_id, context = self.context, ) - self.control_in = self.controller.message_listener( + self.control_in = self.monitor.message_listener( context = self.context ) self.poll.register(self.control_in, self.zmq.POLLIN) self.sockets.extend([self.control_in, self.control_out]) - # ----------- - # FSM Actions - # ----------- - - #@property - #def state(self): - #if not hasattr(self, '_state'): - #self._state = self.initial_state - #else: - #return self._state - - #@state.setter - #def state(self, new): - #if not hasattr(self, '_state'): - #self._state = self.initial_state - - #old = self._state - - #if (old, new) in self.workflow: - #self._state = new - #else: - #raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new)) - # --------------------- # Description and Debug # --------------------- - def extern_logger(self): - """ - Pipe logs out to a provided logging interface. - """ - pass - - def setup_extern_logger(self): - """ - Pipe logs out to a provided logging interface. - """ - pass - @property def get_id(self): """ - The descriptive name of the component. + The time invariant name for this component. + Must be unique within this zipline. """ - # Prevents the bug that Thomas ran into - raise NotImplementedError + return self.prefix + self.component_id - @property - def get_type(self): - """ - The data flow type of the component. - - - ``SOURCE`` - - ``CONDUIT`` - - ``SINK`` - - """ - raise NotImplementedError - - @property - def get_pure(self): - """ - Describes whehter this component purely functional, i.e. for a - given set of inputs is it guaranteed to always give the same - output . Components that are side-effectful are, generally, not - pure. 
- """ - return False + def get_hash(self): + return self.component_id def debug(self): """ @@ -760,18 +632,12 @@ class Component(object): 'pid' : os.getpid() , 'memaddress' : hex(id(self)) , 'ready' : self.successful() , - 'succesfull' : self.ready() , + 'successful' : self.ready() , } - def __len__(self): - """ - Some components overload this for debug purposes - """ - raise NotImplementedError - def __repr__(self): """ - Return a usefull string representation of the component to + Return a useful string representation of the component to indicate its type, unique identifier, and computational context identifier name. """ diff --git a/zipline/core/host.py b/zipline/core/host.py index ea1ca0aa..37de82aa 100644 --- a/zipline/core/host.py +++ b/zipline/core/host.py @@ -103,7 +103,7 @@ class ComponentHost(object): log.info('== Roll Call ==') - log.info('Controller') + log.info('Monitor') self.launch_controller() diff --git a/zipline/core/monitor.py b/zipline/core/monitor.py index 50c036aa..7a920742 100644 --- a/zipline/core/monitor.py +++ b/zipline/core/monitor.py @@ -1,4 +1,3 @@ -import inspect import os import zmq import sys @@ -10,8 +9,13 @@ from signal import SIGHUP, SIGINT from collections import OrderedDict, Counter -from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \ - CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \ +from zipline.protocol import ( + CONTROL_PROTOCOL, + CONTROL_FRAME, + CONTROL_UNFRAME, + CONTROL_STATES, + INVALID_CONTROL_FRAME +) from zipline.utils.protocol_utils import ndict @@ -34,7 +38,7 @@ class UnknownChatter(Exception): return """Component calling itself "%s" talking on unexpected channel""" % self.named -log = logbook.Logger('Controller') +log = logbook.Logger('Monitor') # The scalars determining the timing of the monitor behavior for # the system. 
@@ -52,7 +56,7 @@ PARAMETERS = ndict(dict( SYSTEM_TIMEOUT = 50, )) -class Controller(object): +class Monitor(object): """ A N to M messaging system for inter component communication. @@ -69,7 +73,12 @@ class Controller(object): debug = True period = PARAMETERS.GENERATIONAL_PERIOD - def __init__(self, pub_socket, route_socket, send_sighup=False): + def __init__( + self, + pub_socket, + route_socket, + exception_socket, + send_sighup=False): self.nosignals = False self.context = None @@ -90,13 +99,15 @@ class Controller(object): self.associated = [] - self.pub_socket = pub_socket - self.route_socket = route_socket - - self.error_replay = OrderedDict() + self.pub_socket = pub_socket + self.route_socket = route_socket + self.exception_socket = exception_socket self.missed_beats = Counter() + # start with an empty topology + self.topology = set([]) + self.send_sighup = send_sighup if self.send_sighup: log.info("Request to send sighup/sigint") @@ -108,6 +119,17 @@ class Controller(object): self.zmq_poller = self.zmq.Poller return + def add_to_topology(self, component_id): + add = set([component_id, "FORK-" + component_id]) + self.topology.update(add) + + def freeze_topology(self): + if isinstance(self.topology, frozenset): + return + # we've been incrementally adding components. + # time to freeze. 
+ self.manage(self.topology) + def manage(self, topology): """ Give the controller a set set of components to manage and @@ -139,6 +161,7 @@ class Controller(object): raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new)) def run(self): + self.freeze_topology() self.running = True self.init_zmq() setproctitle('Monitor') @@ -243,6 +266,11 @@ class Controller(object): self.router.bind(self.route_socket) self.router.setsockopt(zmq.LINGER, 0) + # -- Exception Out -- + # =================== + self.ex_out = self.context.socket(self.zmq.PUSH) + self.ex_out.connect(self.exception_socket) + poller = self.zmq.Poller() poller.register(self.router, self.zmq.POLLIN) #poller.register(self.cancel, self.zmq.POLLIN) @@ -312,6 +340,11 @@ class Controller(object): log.info("breaking out of initial heartbeat") break + # Break out if the entire topology told us its DONE + if len(self.finished) == len(self.topology): + break + + # ================ # Heartbeat Stats # ================ @@ -374,6 +407,7 @@ class Controller(object): """ if not self.send_sighup: log.warning("Skipping SIGINT") + return ppid = os.getpid() log.warning("Sending SIGINT") os.kill(ppid, SIGINT) @@ -403,28 +437,22 @@ class Controller(object): bad = self.tracked - good - self.finished new = self.responses - good - self.finished - missing = self.topology - self.tracked - self.finished - for component in new: self.new(component) - if self.debug: - log.info('New component %r' % component) - for component in bad: - self.fail(component) + self.timed_out(component) + missing = self.topology - self.tracked - self.finished for component in missing: - if self.debug: log.info('Missing component %r' % component) - if self.debug: - for component in self.tracked: - if component not in self.topology: - log.info('Uninvited component %r' % component) + for component in self.tracked: + if component not in self.topology: + log.info('Uninvited component %r' % component) # -------------- # Init Handlers @@ -460,22 +488,16 
@@ class Controller(object): # Epic Fail Handling # ------------------ - def fail_universal(self): - # TODO: this requires higher order functionality - log.error('Component missed heartbeat, Monitor shutting down system') - self.kill() - - def fail(self, component): + def timed_out(self, component): if self.state is CONTROL_STATES.TERMINATE: return - universal = self.fail_universal - fail_handlers = { } - if component in (self.topology - self.finished) or self.freeform: log.warning('Component "%s" missed heartbeat' % component) - self.tracked.remove(component) - fail_handlers.get(component, universal)() + # we treat a time out as a severe failure, and + # conduct a rapid shutdown + self.kill() + # ------------------- # Completion Handling @@ -489,26 +511,15 @@ class Controller(object): # -------------- # Error Handling # -------------- - - def exception_universal(self): - """ - Shutdown the system on failure. - """ - log.error('System in exception state, shutting down') + def exception(self, component, exception_data): + log.error('Component in exception state: %s. Shutting down system and sending exception data to listeners.'\ + % component) + # Send the exception message out to listeners. + self.ex_out.send(exception_data) + # An exception in one component is treated as a hard + # failure, and we conduct a rapid shutdown. 
self.kill() - def exception(self, component, failure): - universal = self.exception_universal - exception_handlers = { } - - if component in self.topology or self.freeform: - self.error_replay[(component, time.time())] = failure - log.error('Component in exception state: %s' % component) - - exception_handlers.get(component, universal)() - else: - raise UnknownChatter(component) - # ----------------- # Protocol Handling # ----------------- @@ -555,7 +566,19 @@ class Controller(object): # A component is telling us it failed, and how if id is CONTROL_PROTOCOL.EXCEPTION: - self.exception(identity, status) + # status should be a msgpack emitted from + # EXCEPTION_FRAME + try: + exception_data = status + self.exception(identity, exception_data) + except: + # if an exception occurs when we try to handle + # the exception, signal the parent that we need + # to go down + # TODO: should we attempt to call self.exception? + log.exception("Unexpected exception sending exception data") + self.kill() + return # A component is telling us its done with work and won't @@ -605,11 +628,9 @@ class Controller(object): self.associated.append(s) return s - def do_error_replay(self): - for (component, time), error in self.error_replay.iteritems(): - log.info('Component Log for -- %s --:\n%s' % (component, error)) - def kill(self): + """Aggressively exit the whole zipline. + """ if self.state is CONTROL_STATES.TERMINATE: return @@ -617,6 +638,9 @@ class Controller(object): self.send_hardkill() self.state = CONTROL_STATES.TERMINATE self.alive = False + # send burrito an interrupt, instructing it to kill all + # child processes assocated with this zipline. 
+ time.sleep(3) self.signal_interrupt() def shutdown(self): diff --git a/zipline/core/process.py b/zipline/core/process.py index b2a01429..dc1fcd3c 100644 --- a/zipline/core/process.py +++ b/zipline/core/process.py @@ -40,13 +40,13 @@ class ProcessSimulator(ComponentHost): # invoked by the host's open() def launch_controller(self): - proc = multiprocessing.Process(target=self.controller.run) + proc = multiprocessing.Process(target=self.monitor.run) proc.start() self.con = proc # Process specific - self.controller_process = proc - self.mapping[proc.pid] = 'Controller' + self.monitor_process = proc + self.mapping[proc.pid] = 'Monitor' def launch_component(self, component): proc = multiprocessing.Process(target=component.run) @@ -81,7 +81,7 @@ class ProcessSimulator(ComponentHost): process.join(timeout=1) process.terminate() - self.controller.shutdown(soft=True) + self.monitor.shutdown(soft=True) self.running = False self.con.terminate() diff --git a/zipline/finance/performance.py b/zipline/finance/performance.py index fdca878d..4c96ecce 100644 --- a/zipline/finance/performance.py +++ b/zipline/finance/performance.py @@ -133,6 +133,7 @@ import zipline.finance.risk as risk log = logbook.Logger('Performance') class PerformanceTracker(object): + UPDATER = True """ Tracks the performance of the zipline as it is running in the simulator, relays this out to the Deluge broker and then @@ -144,7 +145,7 @@ class PerformanceTracker(object): """ - def __init__(self, trading_environment): + def __init__(self, trading_environment, sid_list): self.trading_environment = trading_environment self.trading_day = datetime.timedelta(hours = 6, minutes = 30) @@ -164,7 +165,6 @@ class PerformanceTracker(object): self.txn_count = 0 self.event_count = 0 self.last_dict = None - self.order_log = [] self.exceeded_max_loss = False self.results_socket = None @@ -197,10 +197,21 @@ class PerformanceTracker(object): # save the transactions for the daily periods keep_transactions = True ) - - def 
set_sids(self, sid_list): + for sid in sid_list: self.cumulative_performance.positions[sid] = Position(sid) + self.todays_performance.positions[sid] = Position(sid) + + def update(self, event): + if event.dt == "DONE": + event.perf_message = self.handle_simulation_end() + del event['TRANSACTION'] + return event + else: + event.perf_message = self.process_event(event) + event.portfolio = self.get_portfolio() + del event['TRANSACTION'] + return event def get_portfolio(self): return self.cumulative_performance.as_portfolio() @@ -238,10 +249,10 @@ class PerformanceTracker(object): 'cumulative_risk_metrics' : self.cumulative_risk_metrics.to_dict() } - def log_order(self, order): - self.order_log.append(order) def process_event(self, event): + + message = None if self.exceeded_max_loss: return @@ -250,7 +261,7 @@ class PerformanceTracker(object): self.event_count += 1 if(event.dt >= self.market_close): - self.handle_market_close() + message = self.handle_market_close() if event.TRANSACTION: self.txn_count += 1 @@ -264,9 +275,12 @@ class PerformanceTracker(object): #calculate performance as of last trade self.cumulative_performance.calculate_performance() self.todays_performance.calculate_performance() + + + return message def handle_market_close(self): - + # add the return results from today to the list of DailyReturn objects. todays_date = self.market_close.replace(hour=0, minute=0, second=0) todays_return_obj = risk.DailyReturn( @@ -288,12 +302,10 @@ class PerformanceTracker(object): # calculate progress of test self.progress = self.day_count / self.total_days - # Output results - if self.results_socket: - msg = zp.PERF_FRAME(self.to_dict()) - self.results_socket.send(msg) + # Take a snapshot of our current peformance to return to the + # browser. 
+ daily_update = self.to_dict() - # if self.trading_environment.max_drawdown: returns = self.todays_performance.returns max_dd = -1 * self.trading_environment.max_drawdown @@ -304,7 +316,7 @@ class PerformanceTracker(object): # so it shows up in the update, but don't end the test # here. Let the update go out before stopping self.exceeded_max_loss = True - return + return daily_update #move the market day markers forward @@ -326,6 +338,8 @@ class PerformanceTracker(object): self.market_close, keep_transactions = True ) + + return daily_update def handle_simulation_end(self): """ @@ -349,12 +363,8 @@ class PerformanceTracker(object): exceeded_max_loss = self.exceeded_max_loss ) - if self.results_socket: - log.info("about to stream the risk report...") - risk_dict = self.risk_report.to_dict() - - msg = zp.RISK_FRAME(risk_dict) - self.results_socket.send(msg) + risk_dict = self.risk_report.to_dict() + return risk_dict class Position(object): @@ -362,8 +372,8 @@ class Position(object): self.sid = sid self.amount = 0 self.cost_basis = 0.0 ##per share - self.last_sale_price = None - self.last_sale_date = None + self.last_sale_price = 0.0 + self.last_sale_date = 0.0 def update(self, txn): if(self.sid != txn.sid): @@ -584,7 +594,6 @@ class PerformancePeriod(object): return positions - # def get_positions_list(self): positions = [] for sid, pos in self.positions.iteritems(): diff --git a/zipline/finance/returns.py b/zipline/finance/returns.py deleted file mode 100644 index 5585f325..00000000 --- a/zipline/finance/returns.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections import defaultdict -from zipline.transforms.base import BaseTransform - -class ReturnsTransform(BaseTransform): - - def init(self, name): - self.state = {} - self.state['name'] = name - self.by_sid = defaultdict(self._create) - - @property - def get_id(self): - return self.state['name'] - - - def transform(self, event): - cur = self.by_sid[event.sid] - cur.update(event) - self.state['value'] = cur.returns 
- return self.state - - def _create(self): - return ReturnsFromPriorClose() - -class ReturnsFromPriorClose(object): - """ - Calculates a security's returns since the previous close, using the - current price. - """ - - def __init__(self): - self.last_close = None - self.last_event = None - self.returns = 0.0 - - def update(self, event): - if self.last_close: - change = event.price - self.last_close.price - self.returns = change / self.last_close.price - - if self.last_event: - if self.last_event.dt.day != event.dt.day: - # the current event is from the day after - # the last event. Therefore the last event was - # the last close - self.last_close = self.last_event - - # the current event is now the last_event - self.last_event = event diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index bf3a5374..1d82bc66 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -9,10 +9,10 @@ from zipline.protocol import SIMULATION_STYLE log = logbook.Logger('Transaction Simulator') class TransactionSimulator(object): + UPDATER = True - def __init__(self, style=SIMULATION_STYLE.PARTIAL_VOLUME): + def __init__(self, sid_filter, style=SIMULATION_STYLE.PARTIAL_VOLUME): self.open_orders = {} - self.order_count = 0 self.txn_count = 0 self.trade_window = datetime.timedelta(seconds=30) self.orderTTL = datetime.timedelta(days=1) @@ -27,27 +27,20 @@ class TransactionSimulator(object): elif style == SIMULATION_STYLE.NOOP: self.apply_trade_to_open_orders = self.simulate_noop - def add_open_order(self, event): - # Orders are captured in a buffer by sid. No calculations are done here. - # Amount is explicitly converted to an int. - # Orders of amount zero are ignored. + for sid in sid_filter: + self.open_orders[sid] = [] - self.order_count += 1 - event.amount = int(event.amount) + def place_order(self, order): + # initialized filled field. 
+ order.filled = 0 + self.open_orders[order.sid].append(order) - if event.amount == 0: - log = "requested to trade zero shares of {sid}".format( - sid=event.sid - ) - log.debug(log) - return - - if not self.open_orders.has_key(event.sid): - self.open_orders[event.sid] = [] - - # set the filled property to zero - event.filled = 0 - self.open_orders[event.sid].append(event) + def update(self, event): + event.TRANSACTION = None + # We only fill transactions on trade events. + if event.type == zp.DATASOURCE_TYPE.TRADE: + event.TRANSACTION = self.apply_trade_to_open_orders(event) + return event def simulate_buy_all(self, event): txn = self.create_transaction( @@ -81,7 +74,7 @@ class TransactionSimulator(object): txn = self.create_transaction( event.sid, amount, - event.price + 0.10, + event.price + 0.10, # Magic constant? event.dt, direction ) @@ -162,7 +155,6 @@ class TransactionSimulator(object): } return zp.ndict(txn) - class TradingEnvironment(object): def __init__( diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 832909ad..b3fa7576 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -1,91 +1,101 @@ import datetime -from itertools import tee, starmap +from itertools import tee, starmap, chain from collections import namedtuple from zipline.gens.tradegens import SpecificEquityTrades -from zipline.gens.utils import roundrobin, hash_args +from zipline.gens.utils import roundrobin, hash_args, done_message from zipline.gens.sort import date_sort from zipline.gens.merge import merge -from zipline.gens.transform import stateful_transform +from zipline.gens.transform import StatefulTransform -SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs']) -MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs']) - -def date_sorted_sources(sources, source_args, source_kwargs): +def date_sorted_sources(*sources): """ - Takes a list of generator functions, a list of tuples of positional arguments, - and 
a list of dictionaries of keyword arguments. Packages up all arguments - and passes them into a date_sort. + Takes an iterable of sources, generating namestrings and + piping their output into date_sort. """ - assert len(sources) == len(source_args) == len(source_kwargs) - # Package up sources and arguments. - - # Create a generator of SortBundle objects to be turned into - # namestrings and generator objects. - bundle_gen = starmap(SortBundle, zip(sources, source_args, source_kwargs)) - # Load the results of the generator into a tuple so that the - # results can be used twice (once in namestring comprehension, - # once in the generator comprehension for intialized sources. - bundles = tuple(bundle_gen) + for source in sources: + assert iter(source), "Source %s not iterable" % source + assert source.__class__.__dict__.has_key('get_hash'), "No get_hash" - # Calculate namestring hashes to pass to date_sort. - names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs) - for bundle in bundles] - # Pass each source its arguments. - initialized = [bundle.source(*bundle.args, **bundle.kwargs) - for bundle in bundles] + # Get name hashes to pass to date_sort. + names = [source.get_hash() for source in sources] # Convert the list of generators into a flat stream by pulling # one element at a time from each. 
- stream_in = roundrobin(*initialized) - - # Guarantee the flat stream will be sorted by date, using source_id as - # tie-breaker, which is fully deterministic (given deterministic string - # representation for all args/kwargs) + stream_in = roundrobin(sources, names) + + # Guarantee the flat stream will be sorted by date, using + # source_id as tie-breaker, which is fully deterministic (given + # deterministic string representation for all args/kwargs) + return date_sort(stream_in, names) - -def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs): +def merged_transforms(sorted_stream, *transforms): """ - A generator that takes the expected output of a date_sort, pipes it - through a given set of transforms, and runs the results throught a - merge to output a unified stream. tnfms should be a list of - pointers to generator functions. tnfm_args should be a list of - tuples, representing the arguments to be passed to each transform. - tnfm_kwargs should be a list of dictionaries representing keyword - arguments to each transform. + A generator that takes the expected output of a date_sort, pipes + it through a given set of transforms, and runs the results + through a merge to output a unified stream. tnfms should be a + list of pointers to generator functions. tnfm_args should be a + list of tuples, representing the arguments to be passed to each + transform. tnfm_kwargs should be a list of dictionaries + representing keyword arguments to each transform. """ - - # We should have as many sets of args as we have transforms. - assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs) - - # Create a copy of the stream for each transform. - split = tee(sorted_stream, len(tnfms)) - - # Package each transform with a stream copy and set of args. Use a list - # so that we can re-use this for calculating hashes. 
- bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs)) - - bundles = tuple(bundle_gen) - # list comprehension to create transform generators from - # bundles - tnfm_gens = [ - stateful_transform( - bundle.stream, - bundle.tnfm, - *bundle.args, - **bundle.kwargs - ) - for bundle in bundles] + for transform in transforms: + assert isinstance(transform, StatefulTransform) # Generate expected hashes for each transform - hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs) - for bundle in bundles] + namestrings = [tnfm.get_hash() for tnfm in transforms] - # Roundrobin the outputs of our transforms to create a single flat stream. - to_merge = roundrobin(*tnfm_gens) + # Create a copy of the stream for each transform. + split = tee(sorted_stream, len(transforms)) + + # Package a stream copy with each StatefulTransform instance. + bundles = zip(transforms, split) + + # Convert the copies into transform streams. + tnfm_gens = [tnfm.transform(stream) for tnfm, stream in bundles] + + # Roundrobin the outputs of our transforms to create a single flat + # stream. + to_merge = roundrobin(tnfm_gens, namestrings) # Pipe the stream into merge. - merged = merge(to_merge, hashes) - return merged_transforms + merged = merge(to_merge, namestrings) + + dt_aliased = alias_dt(merged) + # Return the merged events. + return add_done(dt_aliased) + +def sequential_transforms(stream_in, *transforms): + """ + Apply each transform in transforms sequentially to each event in stream_in. + Each transform application will add a new entry indexed to the transform's + hash string. + """ + + assert isinstance(transforms, (list, tuple)) + for tnfm in transforms: + tnfm.forward_all = False + tnfm.update_in_place = False + tnfm.append_value = True + + # Recursively apply all transforms to the stream. 
+ stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream), + transforms, + stream_in) + + dt_aliased = alias_dt(stream_out) + return add_done(dt_aliased) + +def alias_dt(stream_in): + """ + Alias the dt field to datetime on each message. + """ + for message in stream_in: + message['datetime'] = message['dt'] + yield message + +# Add a done message to a stream. +def add_done(stream_in): + return chain(stream_in, [done_message('Composite')]) diff --git a/zipline/gens/examples.py b/zipline/gens/examples.py new file mode 100644 index 00000000..f3a0dd0b --- /dev/null +++ b/zipline/gens/examples.py @@ -0,0 +1,100 @@ +import pytz +import time + +from time import sleep +from pprint import pprint as pp +from datetime import datetime, timedelta +from itertools import izip + +from zipline.utils.factory import create_trading_environment +from zipline.test_algorithms import TestAlgorithm + +from zipline.gens.composites import SourceBundle, TransformBundle, \ + date_sorted_sources, merged_transforms, sequential_transforms +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform +from zipline.gens.tradesimulation import TradeSimulationClient as tsc + +import zipline.protocol as zp + +if __name__ == "__main__": + + filter = [2,3] + #Set up source a. Six minutes between events. + args_a = tuple() + kwargs_a = { + 'count' : 1000, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + source_a_prime = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. Five minutes between events. 
+ args_b = tuple() + kwargs_b = { + 'count' : 1000, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + source_b_prime = SpecificEquityTrades(*args_b, **kwargs_b) + + sorted = date_sorted_sources(source_a, source_b) + sorted_prime = date_sorted_sources( + source_a_prime, + source_b_prime + ) + + passthrough = StatefulTransform(Passthrough) + mavg_price = StatefulTransform( + MovingAverage, + timedelta(minutes = 20), + ['price'] + ) + + passthrough_prime = StatefulTransform(Passthrough) + mavg_price_prime = StatefulTransform( + MovingAverage, + timedelta(minutes = 20), + ['price'] + ) + + merged = merged_transforms(sorted, passthrough, mavg_price) + start = time.time() + for message in merged: + assert 1 + 1 == 2 + stop = time.time() + merge_time = stop - start + print "Merge time: %s" % str(merge_time) + + sequential = sequential_transforms( + sorted_prime, + passthrough_prime, + mavg_price_prime + ) + + start = time.time() + for message in sequential: + assert 1 + 1 == 2 + stop = time.time() + seq_time = stop - start + print "Sequential time: %s" % str(seq_time) + print "Merge/Seq: %s" % (str(merge_time/seq_time)) + + +# merged = merged_transforms(sorted, passthrough, mavg_price) + + # algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) +# environment = create_trading_environment(year = 2012) +# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE + +# trading_client = tsc(algo, environment, style) + +# for message in trading_client.simulate(merged): +# pp(message) + diff --git a/zipline/gens/mavg.py b/zipline/gens/mavg.py new file mode 100644 index 00000000..21aa0bd0 --- /dev/null +++ b/zipline/gens/mavg.py @@ -0,0 +1,127 @@ +from numbers import Number +from datetime import datetime, timedelta +from collections import defaultdict + +from zipline import ndict +from zipline.gens.transform import EventWindow + +class 
MovingAverage(object): + """ + Class that maintains a dictionary from sids to + MovingAverageEventWindows. For each sid, we maintain moving + averages over any number of distinct fields (For example, we can + maintain a sid's average volume as well as its average price.) + """ + + def __init__(self, fields, market_aware, days = None, delta = None): + + self.fields = fields + self.market_aware = market_aware + + self.delta = delta + self.days = days + + # Market-aware mode only works with full-day windows. + if self.market_aware: + assert self.days and not self.delta,\ + "Market-aware mode only works with full-day windows." + + # Non-market-aware mode requires a timedelta. + else: + assert self.delta and not self.days, \ + "Non-market-aware mode requires a timedelta." + + # No way to pass arguments to the defaultdict factory, so we + # need to define a method to generate the correct EventWindows. + self.sid_windows = defaultdict(self.create_window) + + def create_window(self): + """ + Factory method for self.sid_windows. + """ + return MovingAverageEventWindow( + self.fields, + self.market_aware, + self.days, + self.delta + ) + + def update(self, event): + """ + Update the event window for this event's sid. Return an ndict + from tracked fields to moving averages. + """ + # This will create a new EventWindow if this is the first + # message for this sid. + window = self.sid_windows[event.sid] + window.update(event) + return window.get_averages() + +class MovingAverageEventWindow(EventWindow): + """ + Iteratively calculates moving averages for a particular sid over a + given time window. We can maintain averages for arbitrarily many + fields on a single sid. (For example, we might track average + price as well as average volume for a single sid.) The expected + functionality of this class is to be instantiated inside a + MovingAverage transform. 
+ """ + + def __init__(self, fields, market_aware, days, delta): + + # Call the superclass constructor to set up base EventWindow + # infrastructure. + EventWindow.__init__(self, market_aware, days, delta) + + # We maintain a dictionary of totals for each of our tracked + # fields. + self.fields = fields + self.totals = defaultdict(float) + + # Subclass customization for adding new events. + def handle_add(self, event): + # Sanity check on the event. + self.assert_required_fields(event) + # Increment our running totals with data from the event. + for field in self.fields: + self.totals[field] += event[field] + + # Subclass customization for removing expired events. + def handle_remove(self, event): + # Decrement our running totals with data from the event. + for field in self.fields: + self.totals[field] -= event[field] + + def average(self, field): + """ + Calculate the average value of our ticks over a single field. + """ + # Sanity check. + assert field in self.fields + + # Averages are None by convention if we have no ticks. + if len(self.ticks) == 0: + return 0.0 + + # Calculate and return the average. len(self.ticks) is O(1). + else: + return self.totals[field] / len(self.ticks) + + def get_averages(self): + """ + Return an ndict of all our tracked averages. + """ + out = ndict() + for field in self.fields: + out[field] = self.average(field) + return out + + def assert_required_fields(self, event): + """ + We only allow events with all of our tracked fields. 
+ """ + for field in self.fields: + assert event.has_key(field), \ + "Event missing [%s] in MovingAverageEventWindow" % field + assert isinstance(event[field], Number), \ + "Got %s for %s in MovingAverageEventWindow" % (event[field], field) diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index 4778ed5b..09e1f943 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -6,13 +6,13 @@ from collections import deque from zipline import ndict from zipline.gens.utils import hash_args, \ - assert_merge_protocol - + assert_merge_protocol, done_message +from itertools import repeat def merge(stream_in, tnfm_ids): """ - A generator that takes a generator and a list of source_ids. We - maintain an internal queue for each id in source_ids. Once we + A generator that takes a generator and a list of transform ids. We + maintain an internal queue for each id in tnfm_ids. Once we have a message from every queue, we pop an event from each queue and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. @@ -28,22 +28,22 @@ def merge(stream_in, tnfm_ids): # Process incoming streams. for message in stream_in: - assert isinstance(message, tuple), \ - "Bad message in merge: %s" %message - assert len(message) == 2 - id, value = message + assert isinstance(message, ndict) + assert message.has_key('tnfm_id') + assert message.has_key('tnfm_value') + assert message.has_key('dt') + + id = message.tnfm_id assert id in tnfm_ids, \ "Message from unexpected tnfm: %s, %s" % (id, tnfm_ids) - assert isinstance(value, ndict), "Bad message in merge: %s" %message - tnfms[id].append(value) + tnfms[id].append(message) # Only pop messages when we have a pending message from # all datasources. Stop if all sources have signalled done. while ready(tnfms) and not done(tnfms): message = merge_one(tnfms) - assert_merge_protocol(tnfm_ids, message) yield message # We should have only a done message left in each queue. 
@@ -51,13 +51,25 @@ def merge(stream_in, tnfm_ids): assert len(queue) == 1, "Bad queue in merge on exit: %s" % queue assert queue[0].dt == "DONE", \ "Bad last message in merge on exit: %s" % queue - + def merge_one(sources): - output = ndict() + + event_fields = ndict() for key, queue in sources.iteritems(): - new_xform = ndict({key: queue.popleft()}) - output.merge(new_xform) - return output + + # Add transform value to the transforms dict. + message = queue.popleft() + event_fields[message.tnfm_id] = message.tnfm_value + del message['tnfm_id'] + del message['tnfm_value'] + + # Merge any remaining fields into the event dict. + event_fields.merge(message) + + # alias dt with datetime, per algoscript api + event_fields['datetime'] = event_fields['dt'] + + return event_fields #TODO: This is replicated in sort. Probably should be one source file. diff --git a/zipline/gens/returns.py b/zipline/gens/returns.py new file mode 100644 index 00000000..49d3e9b5 --- /dev/null +++ b/zipline/gens/returns.py @@ -0,0 +1,74 @@ +from collections import defaultdict, deque + +class Returns(object): + """ + Class that maintains a dictionary from sids to the sid's + closing price N trading days ago. + """ + def __init__(self, days): + self.days = days + self.mapping = defaultdict(self._create) + + def update(self, event): + """ + Update and return the calculated returns for this event's sid. + """ + assert event.has_key('dt') + assert event.has_key('price') + tracker = self.mapping[event.sid] + tracker.update(event) + + return tracker.get_returns() + + def _create(self): + return ReturnsFromPriorClose(self.days) + +class ReturnsFromPriorClose(object): + """ + Records the last N closing events for a given security as well as the + last event for the security. When we get an event for a new day, we + treat the last event seen as the close for the previous day. 
+ """ + + def __init__(self, days): + self.closes = deque() + self.last_event = None + self.returns = 0.0 + self.days = days + + def get_returns(self): + return self.returns + + def update(self, event): + + if self.last_event: + + # Day has changed since the last event we saw. Treat + # the last event as the closing price for its day and + # clear out the oldest close if it has expired. + if self.last_event.dt.date() != event.dt.date(): + + self.closes.append(self.last_event) + + # We keep an event for the end of each trading day, so + # if the number of stored events is greater than the + # number of days we want to track, the oldest close + # is expired and should be discarded. + while len(self.closes) > self.days: + # Pop the oldest event. + self.closes.popleft() + + # We only generate a return value once we've seen enough days + # to give a sensible value. Would be nice if we could query + # db for closes prior to our initial event, but that would + # require giving this transform database creds, which we want + # to avoid. + + if len(self.closes) == self.days: + last_close = self.closes[0].price + change = event.price - last_close + self.returns = change / last_close + + + # the current event is now the last_event + self.last_event = event diff --git a/zipline/gens/sort.py b/zipline/gens/sort.py index f6ff7a5e..9755da74 100644 --- a/zipline/gens/sort.py +++ b/zipline/gens/sort.py @@ -14,7 +14,6 @@ def date_sort(stream_in, source_ids): have messages pending from all sources, we pull the earliest message and yield it. """ - assert isinstance(source_ids, (list, tuple)) # Set up an internal queue for each expected source. @@ -28,7 +27,7 @@ def date_sort(stream_in, source_ids): # Incoming messages should be the output of DATASOURCE_UNFRAME. assert_datasource_unframe_protocol(message), \ "Bad message in date_sort: %s" % message - + # Only allow messages from sources we expect. 
assert message.source_id in sources, "Unexpected source: %s" % message @@ -41,7 +40,7 @@ def date_sort(stream_in, source_ids): message = pop_oldest(sources) assert_sort_protocol(message) yield message - + # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index c7ee74f8..2e8f6bea 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -3,13 +3,14 @@ Tools to generate trade events without a backing store. Useful for testing and zipline development """ import random +import pytz + from itertools import chain, cycle, ifilter, izip from datetime import datetime, timedelta -from zipline.utils.factory import create_trade -from zipline.gens.utils import hash_args, mock_done +from zipline.gens.utils import hash_args, create_trade -def date_gen(start = datetime(2012, 6, 6, 0), +def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc), delta = timedelta(minutes = 1), count = 100): """ @@ -25,9 +26,9 @@ def mock_prices(count, rand = False): """ if rand: - return (random.uniform(0.0, 10.0) for i in xrange(count)) + return (random.uniform(1.0, 10.0) for i in xrange(count)) else: - return (float(i % 11) for i in xrange(1,count+1)) + return (float(i % 10) + 1.0 for i in xrange(count)) def mock_volumes(count, rand = False): """ @@ -49,72 +50,104 @@ def fuzzy_dates(count = 500): for date in date_gen(count = count): yield date + timedelta(seconds = random.randint(-10, 10)) -def SpecificEquityTrades(*args, **config): +class SpecificEquityTrades(object): """ Yields all events in event_list that match the given sid_filter. If no event_list is specified, generates an internal stream of events to filter. Returns all events if filter is None. 
+ + Configuration options: + + count : integer representing number of trades + sids : list of values representing simulated internal sids + start : start date + delta : timedelta between internal events + filter : filter to remove the sids """ - # We shouldn't get any positional arguments. - assert args == () - # Unpack config dictionary with default values. - count = config.get('count', 500) - sids = config.get('sids', [1, 2]) - start = config.get('start', datetime(2012, 6, 6, 0)) - delta = config.get('delta', timedelta(minutes = 1)) + def __init__(self, *args, **kwargs): + # We shouldn't get any positional arguments. + assert len(args) == 0 - # Default to None for event_list and filter. - event_list = config.get('event_list') - filter = config.get('filter') + # Unpack config dictionary with default values. + self.count = kwargs.get('count', 500) + self.sids = kwargs.get('sids', [1, 2]) + self.start = kwargs.get('start', datetime(2008, 6, 6, 15, tzinfo = pytz.utc)) + self.delta = kwargs.get('delta', timedelta(minutes = 1)) - arg_string = hash_args(*args, **config) - namestring = "SpecificEquityTrades" + arg_string - # If we have an event_list, ignore the other arguments and use the list. - # TODO: still append our namestring? - if event_list: - unfiltered = (event for event in event_list) + # Default to None for event_list and filter. + self.event_list = kwargs.get('event_list') + self.filter = kwargs.get('filter') - # Set up iterators for each expected field. - else: - dates = date_gen(count = count, start = start, delta = delta) - prices = mock_prices(count) - volumes = mock_volumes(count) + # Hash_value for downstream sorting. 
+ self.arg_string = hash_args(*args, **kwargs) + + self.generator = self.create_fresh_generator() + + def __iter__(self): + return self + + def next(self): + return self.generator.next() + + def rewind(self): + self.generator = self.create_fresh_generator() + + def get_hash(self): + return self.__class__.__name__ + "-" + self.arg_string + + def create_fresh_generator(self): + + if self.event_list: + unfiltered = (event for event in self.event_list) + + # Set up iterators for each expected field. + else: + dates = date_gen(count=self.count, + start=self.start, + delta=self.delta + ) + prices = mock_prices(self.count) + volumes = mock_volumes(self.count) + sids = cycle(self.sids) + + # Combine the iterators into a single iterator of arguments + arg_gen = izip(sids, prices, volumes, dates) + + # Convert argument packages into events. + unfiltered = (create_trade(*args, source_id = self.get_hash()) + for args in arg_gen) + + # If we specified a sid filter, filter out elements that don't + # match the filter. + if self.filter: + filtered = ifilter(lambda event: event.sid in self.filter, unfiltered) + + # Otherwise just use all events. + else: + filtered = unfiltered + + # Return the filtered event stream. + return filtered + + +# !!!!!!! Deprecated for now !!!!!!!!! + +def RandomEquityTrades(object): + + def __init__(self): + # We shouldn't get any positional args. + assert args == () + + self.count = config.get('count', 500) + self.sids = config.get('sids', [1,2]) + self.filter = config.get('filter') + + dates = fuzzy_dates(count) + prices = mock_prices(count, rand = True) + volumes = mock_volumes(count, rand = True) sids = cycle(sids) - # Combine the iterators into a single iterator of arguments - arg_gen = izip(sids, prices, volumes, dates) - - # Convert argument packages into events. - unfiltered = (create_trade(*args, source_id = namestring) - for args in arg_gen) - - # If we specified a sid filter, filter out elements that don't match the filter. 
- if filter: - filtered = ifilter(lambda event: event.sid in filter, unfiltered) - - # Otherwise just use all events. - else: - filtered = unfiltered - - # Add a done message to the end of the stream. For a live - # datasource this would be handled by the containing Component. - out = chain(filtered, [mock_done(namestring)]) - return out - -def RandomEquityTrades(*args, **config): - # We shouldn't get any positional args. - assert args == () - - count = config.get('count', 500) - sids = config.get('sids', [1,2]) - filter = config.get('filter') - - dates = fuzzy_dates(count) - prices = mock_prices(count, rand = True) - volumes = mock_volumes(count, rand = True) - sids = cycle(sids) - arg_gen = izip(sids, prices, volumes, dates) unfiltered = (create_trade(*args) for args in arg_gen) diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py new file mode 100644 index 00000000..5e56b4f0 --- /dev/null +++ b/zipline/gens/tradesimulation.py @@ -0,0 +1,307 @@ +from logbook import Logger, Processor + +from datetime import datetime, timedelta +from numbers import Integral + +from zipline import ndict + +from zipline.gens.transform import StatefulTransform +from zipline.finance.trading import TransactionSimulator +from zipline.finance.performance import PerformanceTracker +from zipline.utils.log_utils import stdout_only_pipe +from zipline.gens.utils import hash_args + +log = Logger('Trade Simulation') + +class TradeSimulationClient(object): + """ + Generator that takes the expected output of a merge, a user + algorithm, a trading environment, and a simulator style as + arguments. Pipes the merge stream through a TransactionSimulator + and a PerformanceTracker, which keep track of the current state of + our algorithm's simulated universe. Results are fed to the user's + algorithm, which directly inserts transactions into the + TransactionSimulator's order book. 
+ + TransactionSimulator maintains a dictionary from sids to the + unfulfilled orders placed by the user's algorithm. As trade + events arrive, if the algorithm has open orders against the + trade's sid, the simulator will fill orders up to 25% of market + cap. Applied transactions are added to a txn field on the event + and forwarded to PerformanceTracker. The txn field is set to None + on non-trade events and events that do not match any open orders. + + PerformanceTracker receives the updated event messages from + TransactionSimulator, maintaining a set of daily and cumulative + performance metrics for the algorithm. The tracker removes the + txn field from each event it receives, replacing it with a + portfolio field to be fed into the user algo. At the end of each + trading day, the PerformanceTracker also generates a daily + performance report, which is appended to event's perf_report + field. + + Fully processed events are run through a batcher generator, which + batches together events with the same dt field into a single event + to be fed to the algo. The portfolio object is repeatedly + overwritten so that only the most recent snapshot of the universe + is sent to the algo. + """ + + def __init__(self, algo, environment, sim_style): + + self.algo = algo + self.sids = algo.get_sid_filter() + self.environment = environment + self.style = sim_style + self.algo_sim = None + + def get_hash(self): + """ + There should only ever be one TSC in the system. + """ + return self.__class__.__name__ + hash_args() + + def simulate(self, stream_in): + """ + Main generator work loop. + """ + + # Simulate filling any open orders made by the previous run of + # the user's algorithm. Sets the txn field to true on any + # event that results in a filled order. + ordering_client = StatefulTransform( + TransactionSimulator, + self.sids, + style = self.style + ) + with_filled_orders = ordering_client.transform(stream_in) + + # Pipe the events with transactions to perf. 
This will remove + # the txn field added by TransactionSimulator and replace it + # with a portfolio object to be passed to the user's + # algorithm. Also adds a perf_message field which is usually + # none, but contains an update message once per day. + perf_tracker = StatefulTransform( + PerformanceTracker, + self.environment, + self.sids + ) + with_portfolio = perf_tracker.transform(with_filled_orders) + + # Pass the messages from perf along with the trading client's + # state into the algorithm for simulation. We provide the + # trading client so that the algorithm can place new orders + # into the client's order book. + self.algo_sim = AlgorithmSimulator( + with_portfolio, + ordering_client.state, + self.algo, + ) + + # The algorithm will yield a daily_results message (as + # calculated by the performance tracker) at the end of each + # day. It will also yield a risk report at the end of the + # simulation. + for message in self.algo_sim: + yield message + +class AlgorithmSimulator(object): + + def __init__(self, stream_in, order_book, algo): + + self.stream_in = stream_in + + # ========== + # Algo Setup + # ========== + + # We extract the order book from the txn client so that + # the algo can place new orders. + self.order_book = order_book + + self.algo = algo + self.sids = algo.get_sid_filter() + + # Monkey patch the user algorithm to place orders in the + # TransactionSimulator's order book. + self.algo.set_order(self.order) + self.algo.set_logger(Logger("AlgoLog")) + + + # ============== + # Snapshot Setup + # ============== + + # The algorithm's universe as of our most recent event. + self.universe = ndict() + + for sid in self.sids: + self.universe[sid] = ndict() + self.universe.portfolio = None + + # We don't have a datetime for the current snapshot until we + # receive a message. 
+ self.simulation_dt = None + self.this_snapshot_dt = None + + # ============= + # Logging Setup + # ============= + + # Processor function for injecting the algo_dt into + # user prints/logs. + def inject_algo_dt(record): + record.extra['algo_dt'] = self.this_snapshot_dt + self.processor = Processor(inject_algo_dt) + + # This is a class, which is instantiated later + # in run_algorithm. The class provides a generator. + self.stdout_capture = stdout_only_pipe + + self.__generator = None + + def __iter__(self): + return self + + def next(self): + if self.__generator: + return self.__generator.next() + else: + self.__generator = self._gen() + return self.__generator.next() + + def order(self, sid, amount): + """ + Closure to pass into the user's algo to allow placing orders + into the txn_sim's dict of open orders. + """ + assert sid in self.sids, "Order on invalid sid: %i" % sid + order = ndict({ + 'dt' : self.simulation_dt, + 'sid' : sid, + 'amount' : int(amount), + 'filled' : 0 + }) + + # Tell the user if they try to buy 0 shares of something. + if order.amount == 0: + zero_message = "Requested to trade zero shares of {sid}".format( + sid=order.sid + ) + log.debug(zero_message) + # Don't bother placing orders for 0 shares. + return + + # Add non-zero orders to the order book. + # !!!IMPORTANT SIDE-EFFECT!!! + # This modifies the internal state of the transaction + # simulator so that it can fill the placed order when it + # receives its next message. + self.order_book.place_order(order) + + def _gen(self): + """ + Internal generator work loop. + """ + # Capture any output of this generator to stdout and pipe it + # to a logbook interface. Also inject the current algo + # snapshot time to any log record generated. + + with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): + # Call the user's initialize method. + self.algo.initialize() + + for event in self.stream_in: + # Yield any perf messages received to be relayed back to + # the browser. 
+ + if event.perf_message: + yield event.perf_message + del event['perf_message'] + + if event.dt == "DONE": + if self.this_snapshot_dt: + # stop iteration happened + # mid-snapshot, so we have a universe + # snapshot that is not yet processed + # by the algorithm. + self.simulate_current_snapshot() + break + + # This should only happen for the first event we run. + if self.simulation_dt == None: + self.simulation_dt = event.dt + + # ====================== + # Time Compression Logic + # ====================== + + if self.this_snapshot_dt != None: + self.update_current_snapshot(event) + + # The algorithm has been missing events because it took + # too long processing. Update the universe with data from + # this event, then check if enough time has passed that we + # can start a new snapshot. + else: + self.update_universe(event) + if event.dt >= self.simulation_dt: + self.this_snapshot_dt = event.dt + + + + def update_current_snapshot(self, event): + """ + Update our current snapshot of the universe. Call handle_data if + """ + # The new event matches our snapshot dt. Just update the + # universe and move on. + if event.dt == self.this_snapshot_dt: + self.update_universe(event) + + # The new event does not match our snapshot. + else: + self.simulate_current_snapshot() + + # Once we've finished simulating the old snapshot, + # we can update the universe with the new event. + self.update_universe(event) + + # The current event is later than the simulation time, + # which means the algorithm finished quickly enough to + # receive the new event. Start a new snapshot with this + # event's dt. + if event.dt >= self.simulation_dt: + self.this_snapshot_dt = event.dt + + # The algorithm spent enough time processing that it + # missed the new event. Wait to start a new snapshot until + # the events catch up to the algo's simulated dt. 
+ else: + self.this_snapshot_dt = None + + def simulate_current_snapshot(self): + """ + Run the user's algo against our current snapshot and update the algo's + simulated time. + """ + start_tic = datetime.now() + self.algo.handle_data(self.universe) + stop_tic = datetime.now() + + # How long did you take? + delta = stop_tic - start_tic + + # Update the simulation time. + self.simulation_dt = self.this_snapshot_dt + delta + + def update_universe(self, event): + """ + Update the universe with new event information. + """ + # Update our portfolio. + self.universe.portfolio = event.portfolio + + # Update our knowledge of this event's sid + for field in event.keys(): + self.universe[event.sid][field] = event[field] diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index a03a841a..651c337d 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -2,16 +2,21 @@ Generator versions of transforms. """ import types +import pytz -from datetime import datetime +from copy import deepcopy +from datetime import datetime, timedelta from collections import deque, defaultdict from numbers import Number +from abc import ABCMeta, abstractmethod from zipline import ndict +from zipline.utils.tradingcalendar import trading_days_between from zipline.gens.utils import assert_sort_unframe_protocol, \ assert_transform_protocol, hash_args class Passthrough(object): + FORWARDER = True """ Trivial class for forwarding events. 
""" @@ -19,11 +24,9 @@ class Passthrough(object): pass def update(self, event): - assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event - assert event.has_key('sid'), "No sid in Passthrough: %s" % event - assert event.has_key('dt'), "No dt in Passthorughz: %s" % event - return event + pass +# Deprecated def functional_transform(stream_in, func, *args, **kwargs): """ Generic transform generator that takes each message from an in-stream @@ -40,131 +43,195 @@ def functional_transform(stream_in, func, *args, **kwargs): assert_transform_protocol(out_value) yield(namestring, out_value) -def stateful_transform(stream_in, tnfm_class, *args, **kwargs): +class StatefulTransform(object): """ - Generic transform generator that takes each message from an in-stream - and sorts it to a state class. For each call to update, the state - class must produce a message to be fed downstream. + Generic transform generator that takes each message from an + in-stream and passes it to a state object. For each call to + update, the state class must produce a message to be fed + downstream. Any transform class with the FORWARDER class variable + set to true will forward all fields in the original message. + Otherwise only dt, tnfm_id, and tnfm_value are forwarded. """ - - assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \ + def __init__(self, tnfm_class, *args, **kwargs): + assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \ "Stateful transform requires a class." - assert tnfm_class.__dict__.has_key('update'), \ + assert tnfm_class.__dict__.has_key('update'), \ "Stateful transform requires the class to have an update method" - # Create an instance of our transform class. 
- state = tnfm_class(*args, **kwargs) + self.forward_all = tnfm_class.__dict__.get('FORWARDER', False) + self.update_in_place = tnfm_class.__dict__.get('UPDATER', False) + self.append_value = tnfm_class.__dict__.get('APPENDER', False) - # Generate the string associated with this generator's output. - namestring = tnfm_class.__name__ + hash_args(*args, **kwargs) + # You only one special behavior mode can be set. + assert sum(map(int, [self.forward_all, + self.update_in_place, + self.append_value])) <= 1 - for message in stream_in: - assert_sort_unframe_protocol(message) - out_value = state.update(message) - assert_transform_protocol(out_value) - yield (namestring, out_value) + # Create an instance of our transform class. + self.state = tnfm_class(*args, **kwargs) -class MovingAverage(object): - """ - Class that maintains a dictionary from sids to EventWindows - Upon receipt of each message we update the - corresponding window and return the calculated average. + # Create the string associated with this generator's output. + self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs) + + def get_hash(self): + return self.namestring + + def transform(self, stream_in): + return self._gen(stream_in) + + def _gen(self, stream_in): + # IMPORTANT: Messages may contain pointers that are shared with + # other streams, so we only manipulate copies. + + for message in stream_in: + + # allow upstream generators to yield None to avoid + # blocking. + if message == None: + continue + + #TODO: refactor this to avoid unnecessary copying. + + assert_sort_unframe_protocol(message) + message_copy = deepcopy(message) + + # Same shared pointer issue here as above. + tnfm_value = self.state.update(deepcopy(message_copy)) + + # FORWARDER flag means we want to keep all original + # values, plus append tnfm_id and tnfm_value. Used for + # preserving the original event fields when our output + # will be fed into a merge. 
+ if self.forward_all: + out_message = message_copy + out_message.tnfm_id = self.namestring + out_message.tnfm_value = tnfm_value + yield out_message + + # UPDATER flag should be used for transforms that + # side-effectfully modify the event they are passed. + # Updated messages are passed along exactly as they are + # returned to use by our state class. Useful for chaining + # specific transforms that won't be fed to a merge. (See + # the implementation of TradeSimulationClient for example + # usage of this flag with PerformanceTracker and + # TransactionSimulator. + elif self.update_in_place: + yield tnfm_value + + # APPENDER flag should be used to add a single new + # key-value pair to the event. The new key is this + # transform's namestring, and it's value is the value + # returned by state.update(event). This is almost + # identical to the behavior of FORWARDER, except we + # compress the two calculated values (tnfm_id, and + # tnfm_value) into a single field. This mode is used by + # the sequential_transforms composite. + elif self.append_value: + out_message = message_copy + out_message[self.namestring] = tnfm_value + yield out_message + + # If no flags are set, we create a new message containing + # just the tnfm_id, the event's datetime, and the + # calculated tnfm_value. This is the default behavior for + # a transform being fed into a merge. + else: + out_message = ndict() + out_message.tnfm_id = self.namestring + out_message.tnfm_value = tnfm_value + out_message.dt = message_copy.dt + yield out_message + +class EventWindow: """ + Abstract base class for transform classes that calculate iterative + metrics on events within a given timedelta. Maintains a list of + events that are within a certain timedelta of the most recent + tick. Calls self.handle_add(event) for each event added to the + window. Calls self.handle_remove(event) for each event removed + from the window. 
Subclass these methods along with init(*args, + **kwargs) to calculate metrics over the window. - def __init__(self, delta, fields): + The market_aware flag is used to toggle whether the eventwindow + calculates + + See zipline/gens/mavg.py and zipline/gens/vwap.py for example + implementations of moving average and volume-weighted average + price. + """ + # Mark this as an abstract base class. + __metaclass__ = ABCMeta + + def __init__(self, market_aware, days = None, delta = None): + + self.market_aware = market_aware + self.days = days self.delta = delta - self.fields = fields - # No way to pass arguments to the defaultdict factory, so we - # need to define a method to generate the correct EventWindows. - self.sid_windows = defaultdict(self.create_window) - - def create_window(self): - """Factory method for self.sid_windows.""" - return EventWindow(self.delta, self.fields) - - def update(self, event): - """ - Update the event window for this event's sid. Return an ndict from - tracked fields to averages. - """ - - assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event - assert event.has_key('sid'), "No sid in MovingAverage: %s" % event - assert event.has_key('dt'), "No dt in MovingAverage: %s" % event - - output = ndict({'sid': event.sid, 'dt': event.dt}) - # This will create a new EventWindow if this is the first - # message for this sid. - window = self.sid_windows[event.sid] - window.update(event) - averages = window.get_averages() - - # Return the calculated averages along with - output.merge(averages) - return output - -class EventWindow(object): - """ - Maintains a list of events that are within a certain timedelta - of the most recent tick. The expected use of this class is to - track events associated with a single sid. We provide simple - functionality for averages, but anything more complicated - should be handled by a containing class. 
- """ - - def __init__(self, delta, fields): self.ticks = deque() - self.delta = delta - self.fields = fields - self.totals = defaultdict(float) + + # Market-aware mode only works with full-day windows. + if self.market_aware: + assert self.days and not self.delta,\ + "Market-aware mode only works with full-day windows." + + # Non-market-aware mode requires a timedelta. + else: + assert self.delta and not self.days, \ + "Non-market-aware mode requires a timedelta." + + # Set the behavior for dropping events from the back of the + # event window. + if self.market_aware: + self.drop_condition = self.out_of_market_window + else: + self.drop_condition = self.out_of_delta + + @abstractmethod + def handle_add(self, event): + raise NotImplementedError() + + @abstractmethod + def handle_remove(self, event): + raise NotImplementedError() def __len__(self): return len(self.ticks) def update(self, event): self.assert_well_formed(event) + # Add new event and increment totals. self.ticks.append(event) - for field in self.fields: - self.totals[field] += event[field] - # We return a list of all out-of-range events we removed. - out_of_range = [] + # Subclasses should override handle_add to define behavior for + # adding new ticks. + self.handle_add(event) - # Clear out expired events, decrementing totals. - # newest oldest - # | | - # V V + # Clear out any expired events. drop_condition changes depending + # on whether or not we are running in market_aware mode. + # + # oldest newest + # | | + # V V + while self.drop_condition(self.ticks[0].dt, self.ticks[-1].dt): - while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta: - # popleft removes and returns ticks[0] + # popleft removes and returns the oldest tick in self.ticks popped = self.ticks.popleft() - # Decrement totals - for field in self.fields: - self.totals[field] -= popped[field] - # Add the popped element to the list of dropped events. 
- out_of_range.append(popped) - return out_of_range + # Subclasses should override handle_remove to define + # behavior for removing ticks. + self.handle_remove(popped) - def average(self, field): - assert field in self.fields - if len(self.ticks) == 0: - return 0.0 - else: - return self.totals[field] / len(self.ticks) + def out_of_market_window(self, oldest, newest): + return trading_days_between(oldest, newest) >= self.days - def get_averages(self): - """ - Return an ndict of all our tracked averages. - """ - out = ndict() - # out.ticks = len(self.ticks) - for field in self.fields: - out[field] = self.average(field) - return out + def out_of_delta(self, oldest, newest): + return (newest - oldest) >= self.delta + # All event windows expect to receive events with datetime fields + # that arrive in sorted order. def assert_well_formed(self, event): assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event assert event.has_key('dt'), "Missing dt in EventWindow:%s" % event @@ -173,8 +240,3 @@ class EventWindow(object): # Something is wrong if new event is older than previous. 
assert event.dt >= self.ticks[-1].dt, \ "Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0]) - for field in self.fields: - assert event.has_key(field), \ - "Event missing [%s] in EventWindow" % field - assert isinstance(event[field], Number), \ - "Got %s for %s in EventWindow" % (event[field], field) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index e2f859cb..1ac85df6 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -1,6 +1,7 @@ import pytz import numbers +from collections import OrderedDict from hashlib import md5 from datetime import datetime from itertools import izip_longest @@ -16,8 +17,16 @@ def mock_raw_event(sid, dt): } return event -def mock_done(source_id): - return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0}) +def mock_done(id): + return ndict({ + 'dt' : "DONE", + "source_id" : id, + 'tnfm_id' : id, + 'tnfm_value': None, + 'type' : DATASOURCE_TYPE.DONE + }) + +done_message = mock_done def alternate(g1, g2): """Specialized version of roundrobin for just 2 generators.""" @@ -27,16 +36,25 @@ def alternate(g1, g2): if e2 != None: yield e2 -def roundrobin(*args): +def roundrobin(sources, namestrings): """ Takes N generators, pulling one element off each until all inputs are empty. """ - for elem_tuple in izip_longest(*args): - for value in elem_tuple: - if value != None: - yield value + assert len(sources) == len(namestrings) + mapping = OrderedDict(zip(namestrings, sources)) + # While our generators have not been exhausted, pull elements + while mapping.keys() != []: + for namestring, source in mapping.iteritems(): + try: + message = source.next() + # allow sources to yield None to avoid blocking. 
+ if message: + yield message + except StopIteration: + yield done_message(namestring) + del mapping[namestring] def hash_args(*args, **kwargs): """Define a unique string for any set of representable args.""" @@ -48,6 +66,25 @@ def hash_args(*args, **kwargs): hasher.update(combined) return hasher.hexdigest() +def create_trade(sid, price, amount, datetime, source_id = "test_factory"): + row = ndict({ + 'source_id' : source_id, + 'type' : DATASOURCE_TYPE.TRADE, + 'sid' : sid, + 'dt' : datetime, + 'price' : price, + 'volume' : amount + }) + return row + +def sum_true(bool_iterable): + """ + Takes an iterable of boolean values and returns the number of + those values that are True. + """ + return sum(map(int, bool_iterable)) + + def assert_datasource_protocol(event): """Assert that an event meets the protocol for datasource outputs.""" diff --git a/zipline/gens/vwap.py b/zipline/gens/vwap.py new file mode 100644 index 00000000..5a0947d8 --- /dev/null +++ b/zipline/gens/vwap.py @@ -0,0 +1,86 @@ +from numbers import Number +from datetime import datetime, timedelta +from collections import defaultdict + +from zipline import ndict +from zipline.gens.transform import EventWindow + +class VWAP(object): + """ + Class that maintains a dictionary from sids to VWAPEventWindows. + """ + def __init__(self, market_aware, delta=None, days=None): + + self.market_aware = market_aware + self.delta = delta + self.days = days + + # Market-aware mode only works with full-day windows. + if self.market_aware: + assert self.days and not self.delta,\ + "Market-aware mode only works with full-day windows." + + # Non-market-aware mode requires a timedelta. + else: + assert self.delta and not self.days, \ + "Non-market-aware mode requires a timedelta." + + # No way to pass arguments to the defaultdict factory, so we + # need to define a method to generate the correct EventWindows. 
+ self.sid_windows = defaultdict(self.create_window) + + def create_window(self): + """Factory method for self.sid_windows.""" + return VWAPEventWindow( + self.market_aware, + days = self.days, + delta = self.delta + ) + + def update(self, event): + """ + Update the event window for this event's sid. Returns the + current vwap for the sid. + """ + # This will create a new EventWindow if this is the first + # message for this sid. + window = self.sid_windows[event.sid] + window.update(event) + return window.get_vwap() + +class VWAPEventWindow(EventWindow): + """ + Iteratively maintains a vwap for a single sid over a given + timedelta. + """ + def __init__(self, market_aware, days=None, delta=None): + EventWindow.__init__(self, market_aware, days, delta) + self.flux = 0.0 + self.totalvolume = 0.0 + + # Subclass customization for adding new events. + def handle_add(self, event): + # Sanity check on the event. + self.assert_required_fields(event) + self.flux += event.volume * event.price + self.totalvolume += event.volume + + # Subclass customization for removing expired events. + def handle_remove(self, event): + self.flux -= event.volume * event.price + self.totalvolume -= event.volume + + def get_vwap(self): + """ + Return the calculated vwap for this sid. + """ + # By convention, vwap is None if we have no events. + if len(self.ticks) == 0: + return None + else: + return (self.flux / self.totalvolume) + + # We need numerical price and volume to calculate a vwap. 
+ def assert_required_fields(self, event): + assert isinstance(event.price, Number) + assert isinstance(event.volume, Number) diff --git a/zipline/gens/zmq_gens.py b/zipline/gens/zmq_gens.py index 524852a7..e60dae2b 100644 --- a/zipline/gens/zmq_gens.py +++ b/zipline/gens/zmq_gens.py @@ -2,15 +2,17 @@ import zmq import zipline.protocol as zp -def gen_from_zmq(poller, unframe): +def gen_from_zmq(poller, unframe, namestring): """ A generator that takes an initialized zmq poller and yields messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE. """ while True: message = poller.recv() - if message = zp.CONTROL_PROTOCOL.DONE: - yield "DONE" + # Done protocol should now be a message type so that + # done messages can also have source_ids. + if message.type == zp.CONTROL_PROTOCOL.DONE: + yield done_message(message.source_id) break else: yield unframe(message) diff --git a/zipline/gens/zmqgen.py b/zipline/gens/zmqgen.py new file mode 100644 index 00000000..66dbdca3 --- /dev/null +++ b/zipline/gens/zmqgen.py @@ -0,0 +1,19 @@ +import zmq +import zipline.protocol as zp + +def gen_from_pull_socket(socket_uri, context, unframe): + """ + A generator that takes a socket_uri, and yields + messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE. + """ + pull_socket = context.socket(zmq.PULL) + pull_socket.connect(socket_uri) + poller = zmq.Poller() + poller.register(pull_socket, zmq.POLLIN) + + return gen_from_poller(poller, pull_socket, unframe) + + +# this generator needs to know about the source_ids coming in via +# the poller, and need to yield DONE messages for each +# source_id. diff --git a/zipline/lines.py b/zipline/lines.py index 2dc43e2f..9bc8b3ac 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -59,106 +59,190 @@ before invoking simulate. | __init__. 
| +---------------------------------+ """ - -import inspect -import logbook -import zipline.utils.factory as factory - -from zipline.components import DataSource -from zipline.transforms import BaseTransform +import sys +import zmq +import os +from signal import SIGHUP, SIGINT +import multiprocessing +from setproctitle import setproctitle from zipline.test_algorithms import TestAlgorithm -from zipline.components import TradeSimulationClient -from zipline.core.process import ProcessSimulator -from zipline.core.monitor import Controller from zipline.finance.trading import SIMULATION_STYLE +from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe +from zipline.utils import factory -log = logbook.Logger('Lines') +from zipline.test_algorithms import TestAlgorithm + +from zipline.gens.composites import \ + date_sorted_sources, merged_transforms, sequential_transforms +from zipline.gens.transform import Passthrough, StatefulTransform +from zipline.gens.tradesimulation import TradeSimulationClient as tsc +from logbook import Logger, NestedSetup, Processor + +import zipline.protocol as zp + +log = Logger('Lines') + + +class CancelSignal(Exception): + def __init__(self): + pass class SimulatedTrading(object): - """ - Zipline with:: - - _no_ data sources. - - Trade simulation client, which is available to send callbacks on - events and also accept orders to be simulated. - - An order data source, which will receive orders from the trade - simulation client, and feed them into the event stream to be - serialized and order alongside all other data source events. - - transaction simulation transformation, which receives the order - events and estimates a theoretical execution price and volume. 
+ def __init__(self, + sources, + transforms, + algorithm, + environment, + style, + results_socket_uri, + context, + sim_id): - All components in this zipline are subject to heartbeat checks and - a control monitor, which can kill the entire zipline in the event of - exceptions in one of the components or an external request to end the - simulation. - """ + self.date_sorted = date_sorted_sources(*sources) + self.transforms = transforms + # Formerly merged_transforms. + self.with_tnfms = sequential_transforms(self.date_sorted, *self.transforms) + self.trading_client = tsc(algorithm, environment, style) + self.gen = self.trading_client.simulate(self.with_tnfms) + self.results_uri = results_socket_uri + self.results_socket = None + self.context = context + self.sim_id = sim_id - def __init__(self, **config): + # optional process if we fork simulate into an + # independent process. + self.proc = None + self.send_sighup = False + self.logger = Logger(sim_id) + self.print_logger = Logger('Print') + + # exit status flag + self.success = False + + + def simulate(self, blocking=True, send_sighup=False): + + # for non-blocking, + if blocking: + self.run_gen() + else: + self.send_sighup = send_sighup + return self.fork_and_sim() + + def fork_and_sim(self): + self.proc = multiprocessing.Process(target=self.run_gen) + self.proc.start() + return self.proc + + def run_gen(self): + setproctitle(self.sim_id) + self.open() + if self.zmq_out: + with self.zmq_out.threadbound(): + self.stream_results() + # if no log socket, just run the algo normally + else: + self.stream_results() + + def stream_results(self): + assert self.results_socket, \ + "Results socket must exist to stream results" + try: + for event in self.gen: + if event.has_key('daily_perf'): + msg = zp.PERF_FRAME(event) + else: + msg = zp.RISK_FRAME(event) + self.results_socket.send(msg) + + self.signal_done() + self.success = True + except Exception as exc: + self.handle_exception(exc) + finally: + # not much to do 
besides log our exit. + self.close() + + def signal_done(self): + # notify monitor we're done + done_frame = zp.DONE_FRAME('success') + self.results_socket.send(done_frame) + + def close(self): + log.info("Closing Simulation: {id}".format(id=self.sim_id)) + if self.proc and self.send_sighup: + ppid = os.getppid() + if self.success: + log.warning("Sending SIGHUP") + os.kill(ppid, SIGHUP) + else: + log.warning("Sending SIGINT") + os.kill(ppid, SIGINT) + + def handle_exception(self, exc): + if isinstance(exc, CancelSignal): + # signal from monitor of an orderly shutdown, + # do nothing. + pass + else: + self.signal_exception(exc) + + def signal_exception(self, exc=None): """ - :param config: a dict with the following required properties:: + All exceptions inside any component should boil back to + this handler. - - algorithm: a class that follows the algorithm protocol. See - :py:meth:`zipline.finance.trading.TradeSimulationClient.add_algorithm - for details. - - trading_environment: an instance of - :py:class:`zipline.trading.TradingEnvironment` - - allocator: an instance of - :py:class:`zipline.simulator.AddressAllocator` - - simulation_style: optional parameter that configures the - :py:class:`zipline.finance.trading.TransactionSimulator`. Expects - a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading` + Will inform the system that the component has failed and how it + has failed. 
""" - assert isinstance(config, dict) - self.algorithm = config['algorithm'] - self.allocator = config['allocator'] - self.trading_environment = config['trading_environment'] - self.sim_style = config.get('simulation_style') - self.send_sighup = config.get('send_sighup', False) + exc_type, exc_value, exc_traceback = sys.exc_info() + try: + log.exception('{id} sending exception to result stream.'\ + .format(id=self.sim_id)) + msg = zp.EXCEPTION_FRAME( + exc_traceback, + exc_type.__name__, + exc_value.message + ) - self.leased_sockets = [] - self.sim_context = None + self.results_socket.send(msg) - sockets = self.allocate_sockets(7) - addresses = { - 'sync_address' : sockets[0], - 'data_address' : sockets[1], - 'feed_address' : sockets[2], - 'merge_address' : sockets[3], - # TODO: this refers to the results of the merge, a - # horribly confusing name for the socket. - 'results_address' : sockets[4], - } + except: + log.exception("Exception while reporting simulation exception.") - self.con = Controller( - sockets[5], - sockets[6], - self.send_sighup + def open(self): + if not self.context: + self.context = zmq.Context() + if self.results_uri: + sock = self.context.socket(zmq.PUSH) + sock.connect(self.results_uri) + self.results_socket = sock + self.setup_logging() + + def setup_logging(self): + assert self.results_socket + # The filter behavior is: matches are logged, mismatches + # are bubbled. If bubble is True, matches are also + # bubbled. Since we do not want user logs in our system + # logs, we set bubble to False. 
+ self.zmq_out = ZeroMQLogHandler( + socket = self.results_socket, + filter = lambda r, h: r.channel in ['Print', 'AlgoLog'], + bubble=False ) - self.started = False + def join(self): + if self.proc: + self.proc.join() - self.sim = ProcessSimulator(addresses) - - self.clients = {} - - self.trading_client = TradeSimulationClient( - self.trading_environment, - self.sim_style, - config['results_socket'] - ) - self.add_client(self.trading_client) - - # setup all sources - self.sources = {} - - #setup transforms - self.transforms = {} - - self.sim.register_controller( self.con ) - - self.trading_client.set_algorithm(self.algorithm) + def get_pids(self): + if self.proc: + return [self.proc.pid] + else: + return [] @staticmethod def create_test_zipline(**config): @@ -167,7 +251,6 @@ class SimulatedTrading(object): - environment - a \ :py:class:`zipline.finance.trading.TradingEnvironment` - - allocator - a :py:class:`zipline.simulator.AddressAllocator` - sid - an integer, which will be used as the security ID. - order_count - the number of orders the test algo will place, defaults to 100 @@ -182,10 +265,10 @@ class SimulatedTrading(object): - simulation_style: optional parameter that configures the :py:class:`zipline.finance.trading.TransactionSimulator`. Expects a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading` + - transforms: optional parameter that provides a list + of StatefulTransform objects. 
""" assert isinstance(config, dict) - - allocator = config['allocator'] sid = config['sid'] #-------------------- @@ -217,6 +300,10 @@ class SimulatedTrading(object): if not simulation_style: simulation_style = SIMULATION_STYLE.FIXED_SLIPPAGE + zmq_context = config.get('zmq_context', None) + simulation_id = config.get('simulation_id', 'test_simulation') + results_socket_uri = config.get('results_socket_uri', None) + #------------------- # Trade Source #------------------- @@ -230,6 +317,12 @@ class SimulatedTrading(object): trade_count, trading_environment ) + + #------------------- + # Transforms + #------------------- + transforms = config.get('transforms', []) + #------------------- # Create the Algo #------------------- @@ -242,157 +335,19 @@ class SimulatedTrading(object): order_count ) - if config.has_key('results_socket'): - results_socket = config['results_socket'] - else: - results_socket = None #------------------- # Simulation #------------------- - zipline = SimulatedTrading(**{ - 'algorithm' : test_algo, - 'trading_environment' : trading_environment, - 'allocator' : allocator, - 'simulation_style' : simulation_style, - 'results_socket' : results_socket, - }) + + sim = SimulatedTrading( + [trade_source], + transforms, + test_algo, + trading_environment, + simulation_style, + results_socket_uri, + zmq_context, + simulation_id) #------------------- - zipline.add_source(trade_source) - - return zipline - - def add_source(self, source): - """ - Adds the source to the zipline, sets the sid filter of the - source to the algorithm's sid filter. 
- """ - assert isinstance(source, DataSource) - self.check_started() - source.set_filter('sid', self.algorithm.get_sid_filter()) - self.sim.register_components([source]) - - # ``id`` is name of source_id, ``get_id`` is the class name - self.sources[source.get_id] = source - - def add_transform(self, transform): - assert isinstance(transform, BaseTransform) - self.check_started() - self.sim.register_components([transform]) - self.transforms[transform.get_id] = transform - - def add_client(self, client): - assert isinstance(client, TradeSimulationClient) - self.check_started() - self.sim.register_components([client]) - self.clients[client.get_id] = client - - def check_started(self): - if self.started: - raise ZiplineException("TradeSimulation", "You cannot add \ - components after the simulation has begun.") - - def get_cumulative_performance(self): - return self.trading_client.perf.cumulative_performance.to_dict() - - def allocate_sockets(self, n): - """ - Allocate sockets local to this line, track them so - we can gc after test run. - """ - - assert isinstance(n, int) - assert n > 0 - - leased = self.allocator.lease(n) - self.leased_sockets.extend(leased) - - return leased - - @property - def components(self): - """ - Return the component instances inside of this topology - """ - - base = set(self.sim.components.values()) - transforms = set(self.transforms.values()) - sources = set(self.sources.values()) - - return base | transforms | sources - - @property - def topology(self): - """ - Returns the Component names in the topology of the - backtest. - """ - - # A complete topology is the union of three classes of - # components added individually to the simulation client - # at various places. - # - # base : ['FEED', 'MERGE', 'TRADING_CLIENT', 'PASSTHROUGH'] - # transforms : ['vwap__01', ... ] - # sources : ['MongoTradeHistory', ... 
] - - base = set(self.sim.components.keys()) - transforms = set(self.transforms.keys()) - sources = set(self.sources.keys()) - - return base | transforms | sources - - def setup_controller(self): - """ - Prepare the controller to manage the topology specified - by this line. - """ - self.con.manage(self.topology) - - def simulate(self, blocking=True): - self.setup_controller() - - self.started = True - self.sim_context = self.sim.simulate() - - # If we're using a threaded simulator block on the pool - # of thread since we're only ever in a test and we don't - # generally monitor the state of the system as a hold at - # the supervisory layer - - # TODO: better way of identifying concurrency substrate - if blocking: - for process in self.sim.subprocesses: - process.join() - - @property - def is_success(self): - # TODO: other assertions? - if self.sim.did_clean_shutdown(): - return True - else: - return False - - #-------------------------------- - # Component property accessors - #-------------------------------- - - def get_positions(self): - """ - returns current positions as a dict. draws from the cumulative - performance period in the performance tracker. 
- """ - perf = self.trading_client.perf.cumulative_performance - positions = perf.get_positions() - return positions - -class ZiplineException(Exception): - def __init__(self, zipline_name, msg): - self.name = zipline_name - self.message = msg - - def __str__(self): - return "Unexpected exception {line}: {msg}".format( - line=self.name, - msg=self.message - ) + return sim diff --git a/zipline/protocol.py b/zipline/protocol.py index 7aa503d7..dd46bd60 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -131,7 +131,7 @@ from utils.date_utils import EPOCH, UN_EPOCH, epoch_now # Control Protocol # ----------------------- -PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION', 'CANCEL'] +PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION','CANCEL','DONE', 'LOG'] INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL') @@ -527,11 +527,15 @@ def EXCEPTION_FRAME(exception_tb, name, message): rlist = [] for stack in stack_list: filename = shorten_filename(stack[0]) + # default the line to empty string rather than None + line = '' + if stack[3]: + line = stack[3] rstack = { 'filename' : filename, 'lineno' : stack[1], 'method' : stack[2], - 'line' : stack[3] + 'line' : line } rlist.append(rstack) result = { @@ -570,6 +574,12 @@ def CANCEL_FRAME(date): return BT_UPDATE_FRAME('CANCEL', result) +def DONE_FRAME(msg): + assert isinstance(msg, basestring), \ + "Done message must be a string." 
+ + return BT_UPDATE_FRAME('DONE', msg) + def BT_UPDATE_FRAME(prefix, payload): """ diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 68f3d160..a7881fa8 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -221,6 +221,32 @@ class DivByZeroAlgorithm(): def get_sid_filter(self): return [self.sid] +class TimeoutAlgorithm(): + + def __init__(self, sid): + self.sid = sid + self.incr = 0 + + def initialize(self): + pass + + def set_order(self, order_callable): + pass + + def set_logger(self, logger): + pass + + def set_portfolio(self, portfolio): + pass + + def handle_data(self, data): + if self.incr > 4: + import time + time.sleep(100) + pass + + def get_sid_filter(self): + return [self.sid] class TestPrintAlgorithm(): diff --git a/zipline/utils/date_utils.py b/zipline/utils/date_utils.py index a1fbfad1..2819d4e9 100644 --- a/zipline/utils/date_utils.py +++ b/zipline/utils/date_utils.py @@ -95,7 +95,7 @@ HOLIDAYS = { 'july_4th' : datetime(2008 , 7 , 4 ), 'labor_day' : datetime(2008 , 9 , 1 ), 'tgiving' : datetime(2008 , 11 , 27), - 'christmas' : datetime(2008 , 5 , 25), + 'christmas' : datetime(2008 , 12 , 25), } # Create a rule to recur every weekday starting today diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index db440891..e3d92443 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -12,7 +12,9 @@ from datetime import datetime, timedelta import zipline.finance.risk as risk import zipline.protocol as zp -from zipline.finance.sources import SpecificEquityTrades, RandomEquityTrades +from zipline.gens.tradegens import RandomEquityTrades +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.utils import create_trade from zipline.finance.trading import TradingEnvironment # TODO @@ -69,16 +71,6 @@ def create_trading_environment(year=2006): return trading_environment -def create_trade(sid, price, amount, datetime, source_id = "test_factory"): - row = zp.ndict({ - 
'source_id' : source_id, - 'type' : zp.DATASOURCE_TYPE.TRADE, - 'sid' : sid, - 'dt' : datetime, - 'price' : price, - 'volume' : amount - }) - return row def get_next_trading_dt(current, interval, trading_calendar): next = current @@ -89,8 +81,6 @@ def get_next_trading_dt(current, interval, trading_calendar): return next - - def create_trade_history(sid, prices, amounts, interval, trading_calendar): trades = [] current = trading_calendar.first_open @@ -222,29 +212,19 @@ def create_minutely_trade_source(sids, trade_count, trading_environment): ) def create_trade_source(sids, trade_count, trade_time_increment, trading_environment): - trade_history = [] - price = [10.1] * trade_count - volume = [100] * trade_count + args = tuple() + kwargs = { + 'count' : trade_count, + 'sids' : sids, + 'start' : trading_environment.first_open, + 'delta' : trade_time_increment, + 'filter' : sids + } + source = SpecificEquityTrades(*args, **kwargs) - for sid in sids: - start_date = trading_environment.first_open + # TODO: do we need to set the trading environment's end to same dt as + # the last trade in the history? + #trading_environment.period_end = trade_history[-1].dt - generated_trades = create_trade_history( - sid, - price, - volume, - trade_time_increment, - trading_environment - ) - - trade_history.extend(generated_trades) - - trade_history = sorted(trade_history, key=attrgetter('dt')) - - #set the trading environment's end to same dt as the last trade in the - #history. - trading_environment.period_end = trade_history[-1].dt - - source = SpecificEquityTrades(trade_history) return source diff --git a/zipline/utils/log_utils.py b/zipline/utils/log_utils.py index 6bcf80f8..f9fbc57c 100644 --- a/zipline/utils/log_utils.py +++ b/zipline/utils/log_utils.py @@ -120,12 +120,6 @@ class ZeroMQLogHandler(Handler): #can't be serialized by JSON, so we need to convert to #unix epoch representation. 
- if record.time: - assert isinstance(record.time, datetime.datetime) - - time = record.time.replace(tzinfo = pytz.utc) - #logbook measures time in utc already, no need to convert. - record.time = EPOCH(time) #Do the same if algo_dt is a datetime object. if record.extra.has_key('algo_dt'): @@ -151,6 +145,14 @@ class ZeroMQLogHandler(Handler): data[field] = record.extra[field] else: data[field] = None + + if data['time']: + assert isinstance(data['time'], datetime.datetime) + + time = data['time'].replace(tzinfo = pytz.utc) + #logbook measures time in utc already, no need to convert. + data['time'] = EPOCH(time) + return data def emit(self, record): diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 6655c265..2dfcff59 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -1,11 +1,14 @@ +import multiprocessing import zmq import time import zipline.protocol as zp from datetime import datetime import blist +from bson import ObjectId from zipline.utils.date_utils import EPOCH from itertools import izip from logbook import FileHandler +from zipline.core.monitor import Monitor def setup_logger(test, path='/var/log/zipline/zipline.log'): test.log_handler = FileHandler(path) @@ -29,6 +32,7 @@ def check_dict(test, a, b, label): # ignore the extra fields used by dictshield if key in ['progress']: continue + test.assertTrue(a.has_key(key), "missing key at: " + label + "." + key) test.assertTrue(b.has_key(key), "missing key at: " + label + "." 
+ key) a_val = a[key] @@ -58,15 +62,22 @@ def check(test, a, b, label=None): else: test.assertEqual(a, b, "mismatch on path: " + label) +def check_excluded(test, a, excluded_keys=[]): + for key, value in a.iteritems(): + test.assertTrue(key not in excluded_keys) + test.assertFalse(key.endswith('_id'), 'Avoid _id fields!') + test.assertFalse(isinstance(value, ObjectId)) + if isinstance(value, dict): + check_excluded(test, value, excluded_keys) -def drain_zipline(test, zipline): +def drain_zipline(test, zipline, p_blocking=False): assert test.ctx, "method expects a valid zmq context" assert test.zipline_test_config, "method expects a valid test config" assert isinstance(test.zipline_test_config, dict) - assert test.zipline_test_config['results_socket'], \ + assert test.zipline_test_config['results_socket_uri'], \ "need to specify a socket address for logs/perf/risk" test.receiver = create_receiver( - test.zipline_test_config['results_socket'], + test.zipline_test_config['results_socket_uri'], test.ctx ) # Bind and connect are asynch, so allow time for bind before @@ -74,13 +85,12 @@ def drain_zipline(test, zipline): time.sleep(1) # start the simulation - zipline.simulate(blocking=False) + zipline.simulate(blocking=p_blocking) output, transaction_count = drain_receiver(test.receiver) # some processes will exit after the message stream is # finished. We block here to avoid collisions with subsequent # ziplines. 
- for process in zipline.sim.subprocesses: - process.join() + zipline.join() return output, transaction_count @@ -95,16 +105,15 @@ def drain_receiver(receiver): transaction_count = 0 while True: msg = receiver.recv() - if msg == str(zp.CONTROL_PROTOCOL.DONE): + update = zp.BT_UPDATE_UNFRAME(msg) + output.append(update) + if update['prefix'] == 'PERF': + transaction_count += \ + len(update['payload']['daily_perf']['transactions']) + elif update['prefix'] == 'EXCEPTION': + break + elif update['prefix'] == 'DONE': break - else: - update = zp.BT_UPDATE_UNFRAME(msg) - output.append(update) - if update['prefix'] == 'PERF': - transaction_count += \ - len(update['payload']['daily_perf']['transactions']) - elif update['prefix'] == 'EXCEPTION': - break receiver.close() del receiver @@ -115,9 +124,6 @@ def drain_receiver(receiver): def assert_single_position(test, zipline): output, transaction_count = drain_zipline(test, zipline) - test.assertTrue(zipline.sim.ready()) - test.assertFalse(zipline.sim.exception) - test.assertEqual( test.zipline_test_config['order_count'], transaction_count @@ -126,7 +132,8 @@ def assert_single_position(test, zipline): # the final message is the risk report, the second to # last is the final day's results. Positions is a list of # dicts. 
- closing_positions = output[-2]['payload']['daily_perf']['positions'] + perfs = [x for x in output if x['prefix'] == 'PERF'] + closing_positions = perfs[-2]['payload']['daily_perf']['positions'] test.assertEqual( len(closing_positions), @@ -140,3 +147,58 @@ def assert_single_position(test, zipline): sid, "Portfolio should have one position in " + str(sid) ) + + +def launch_component(component): + proc = multiprocessing.Process(target=component.run) + proc.start() + return proc + +def launch_monitor(monitor): + proc = multiprocessing.Process(target=monitor.run) + proc.start() + return proc + + +def create_monitor(allocator): + sockets = allocator.lease(3) + mon = Monitor( + # pub socket + sockets[0], + # route socket + sockets[1], + # exception socket to match tradesimclient's result + # socket, because we want to relay exceptions to the + # same listener + sockets[2], + # this controller is expected to run in a test, so no + # need to signal the parent process on success or error. + send_sighup=False + ) + + return mon + +class ExceptionSource(object): + + def __init__(self): + pass + + def get_hash(self): + return "ExceptionSource" + + def __iter__(self): + return self + + def next(self): + 5 / 0 + +class ExceptionTransform(object): + + def __init__(self): + pass + + def get_hash(self): + return "ExceptionTransform" + + def update(self, event): + assert False, "An assertion message" diff --git a/zipline/utils/tradingcalendar.py b/zipline/utils/tradingcalendar.py new file mode 100644 index 00000000..f760e51e --- /dev/null +++ b/zipline/utils/tradingcalendar.py @@ -0,0 +1,358 @@ +import pytz + +from datetime import datetime, timedelta +from dateutil import rrule +from zipline.utils.date_utils import utcnow + +def market_opens(start, end, inclusive=False): + """ + Returns all market opens between the start date and the end date. + Must use utc-stamped datetimes. 
+    """
+    return opens.between(start, end, inc=inclusive)
+
+def market_closes(start, end, inclusive=False):
+    """
+    Returns all market closes between the start date and the end date.
+    Must use utc-stamped datetimes.
+    """
+    return closes.between(start, end, inc=inclusive)
+
+def trading_days_between(start, end):
+    """
+    Calculate the number of "complete" trading days between two
+    events. We define this as the number of market opens that
+    occurred between start and end, with the caveat that we subtract 1
+    from this total if end falls on the same day as the last market
+    open and end occurs earlier in its own day than start. This
+    reflects the fact that we haven't completed a full day
+    corresponding to the last market open.
+
+    Examples:
+
+    1.)
+    start = Tuesday, Aug 7, 2012, 1:00 pm
+    end = Wednesday, Aug 8, 2012, 1:30 pm
+
+    There is one market open between these dates, on the morning of
+    Wednesday the 8th. This falls on the same calendar day as end,
+    but end is later in the day than start, so we count this as a full
+    day. The correct output is 1.
+
+    2.)
+    start = Tuesday, Aug 7, 2012, 1:30 pm
+    end = Wednesday, Aug 8, 2012, 1:00 pm
+
+    There is one market open between these dates, on the morning of
+    Wednesday the 8th. This falls on the same calendar day as end,
+    and end is earlier in the day than start, so we do not count this
+    day as completed. The correct output is 0.
+
+    3.)
+    start = Tuesday, Aug 7, 2012, 1:00 pm
+    end = Saturday, Aug 11, 2012, 1:30 pm
+
+    There are 3 market opens between these dates, occurring on
+    Wednesday, Thursday, and Friday. The last open is not on
+    the same day as end, so we simply return 3.
+
+    4.)
+    start = Tuesday, Aug 7, 2012, 1:30 pm
+    end = Monday, Aug 13, 2012, 1:00 pm
+
+    There are 4 market opens between these dates, occurring on
+    Wednesday, Thursday, Friday, and the following Monday. The
+    last open occurs on the same calendar day as end, and end
+    is earlier in the day than start, so we do not count the
+    last market day as completed. The correct output is 3 days.
+    """
+    # Opens strictly between the two events (endpoints are excluded).
+    opens_between = market_opens(start, end)
+    days_between = len(opens_between)
+    if days_between == 0:
+        return days_between
+
+    # If end falls on the same day as the last open, subtract 1 from the
+    # total if end is earlier in its respective day than start.
+    last_open = opens_between[-1]
+    if last_open.date() == end.date() and earlier_in_day(end, start):
+        days_between -= 1
+
+    return days_between
+
+def earlier_in_day(d1, d2):
+    """
+    Return true if d1's time of day is strictly earlier than d2's.
+    """
+    return d1.time() < d2.time()
+
+WEEKDAYS = [rrule.MO, rrule.TU, rrule.WE, rrule.TH, rrule.FR]
+
+# Recurrence rule that generates every weekday market open (14:30 UTC)
+# from Jan 1, 2000 until Jan 1, 2014. This does not exclude holidays.
+market_opens_with_holidays = rrule.rrule(
+    rrule.DAILY,
+    byweekday=WEEKDAYS,
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rule that generates every weekday market close (21:00 UTC)
+# over the same 2000-2014 span as the opens. Holidays are not excluded.
+market_closes_with_holidays = rrule.rrule(
+    rrule.DAILY,
+    byweekday=WEEKDAYS,
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rules for excluding the market open/close on New Year's Day.
+new_years_opens = rrule.rrule(
+    rrule.MONTHLY,
+    byyearday=1,
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+new_years_closes = rrule.rrule(
+    rrule.MONTHLY,
+    byyearday=1,
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rules for excluding Martin Luther King Jr. Day, which is
+# always the third Monday in January.
+mlk_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=1,
+    byweekday=rrule.MO(3),
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+mlk_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=1,
+    byweekday=rrule.MO(3),
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rules for the market open/close on Presidents' Day.
+# Presidents' Day always occurs on the third Monday of February, so it
+# uses the same MO(3) weekday form as the MLK rules above.
+presidents_day_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=2,
+    byweekday=rrule.MO(3),
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+presidents_day_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=2,
+    byweekday=rrule.MO(3),
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rules for the market open/close on Good Friday. Good
+# Friday always falls 2 days before Easter, which thankfully is a
+# built-in reference in the rrule module (the byeaster offset).
+good_friday_opens = rrule.rrule(
+    rrule.DAILY,
+    byeaster=-2,
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+good_friday_closes = rrule.rrule(
+    rrule.DAILY,
+    byeaster=-2,
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rules for generating the market open/close for Memorial
+# Day. Memorial Day always occurs on the last Monday of May.
+memorial_day_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=5,
+    byweekday=rrule.MO(-1),
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+memorial_day_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=5,
+    byweekday=rrule.MO(-1),
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rules for the market open/close on July 4th (bymonth=7).
+july_4th_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=7,
+    bymonthday=4,
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+july_4th_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=7,
+    bymonthday=4,
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rule for generating the market open/close for Labor Day.
+# Labor Day is always the first Monday of September.
+labor_day_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=9,
+    byweekday=rrule.MO(1),
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+labor_day_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=9,
+    byweekday=rrule.MO(1),
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rule for generating the market open/close for
+# Thanksgiving. Thanksgiving always falls on the fourth Thursday in
+# November, which is not the last Thursday when the month has five.
+thanksgiving_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=11,
+    byweekday=rrule.TH(4),
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+thanksgiving_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=11,
+    byweekday=rrule.TH(4),
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# Recurrence rule for generating the market open/close for
+# Christmas. Christmas always occurs on December 25th.
+
+christmas_opens = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=12,
+    bymonthday=25,
+    byhour=14,
+    byminute=30,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+christmas_closes = rrule.rrule(
+    rrule.MONTHLY,
+    bymonth=12,
+    bymonthday=25,
+    byhour=21,
+    byminute=0,
+    cache=True,
+    dtstart=datetime(2000, 1, 1, tzinfo=pytz.utc),
+    until=datetime(2014, 1, 1, tzinfo=pytz.utc)
+)
+
+# The holiday rules that are subtracted from the weekday schedule.
+holiday_opens = [
+    new_years_opens,
+    mlk_opens,
+    presidents_day_opens,
+    good_friday_opens,
+    memorial_day_opens,
+    july_4th_opens,
+    labor_day_opens,
+    thanksgiving_opens,
+    christmas_opens
+]
+holiday_closes = [
+    new_years_closes,
+    mlk_closes,
+    presidents_day_closes,
+    good_friday_closes,
+    memorial_day_closes,
+    july_4th_closes,
+    labor_day_closes,
+    thanksgiving_closes,
+    christmas_closes
+]
+
+# Valid opens/closes: the weekday rule with each holiday rule excluded.
+opens = rrule.rruleset(cache=True)
+opens.rrule(market_opens_with_holidays)
+for holiday_rule in holiday_opens:
+    opens.exrule(holiday_rule)
+
+closes = rrule.rruleset(cache=True)
+closes.rrule(market_closes_with_holidays)
+for holiday_rule in holiday_closes:
+    closes.exrule(holiday_rule)
+
+# Counting walks each ruleset once at import time, filling its cache.
+open_count = opens.count()
+close_count = closes.count()
+