from contextlib import contextmanager
from itertools import (
    product,
)
import operator
import os
import shutil
import tempfile

from logbook import FileHandler
from mock import patch
from numpy.testing import assert_array_equal
import pandas as pd
from six import (
    itervalues,
)
from six.moves import filter

from zipline.finance.blotter import ORDER_STATUS
from zipline.utils import security_list


# The Unix epoch as a tz-aware UTC timestamp.
EPOCH = pd.Timestamp(0, tz='UTC')


def seconds_to_timestamp(seconds):
    """Convert seconds since the Unix epoch to a UTC ``pd.Timestamp``."""
    return pd.Timestamp(seconds, unit='s', tz='UTC')


def to_utc(time_str):
    """Convert a string in US/Eastern time to a UTC ``pd.Timestamp``."""
    return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')


def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since UTC.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())


def setup_logger(test, path='test.log'):
    """
    Install a logbook ``FileHandler`` writing to ``path`` as the
    application-wide handler, and store it on ``test.log_handler`` so that
    ``teardown_logger`` can remove it.
    """
    test.log_handler = FileHandler(path)
    test.log_handler.push_application()


def teardown_logger(test):
    """Pop and close the handler installed by ``setup_logger``."""
    test.log_handler.pop_application()
    test.log_handler.close()


def drain_zipline(test, zipline):
    """
    Run ``zipline`` to completion, collecting every update it yields.

    Parameters
    ----------
    test : object
        The calling test case. Unused here; accepted so the signature
        matches the other helpers in this module.
    zipline : iterable
        The simulation to drain. Updates are indexed by key below, so they
        are presumably dicts.

    Returns
    -------
    output : list
        Every update yielded by ``zipline``, in order.
    transaction_count : int
        Total number of transactions across all updates that contain a
        ``'daily_perf'`` entry.
    """
    output = []
    transaction_count = 0
    # Iterating drives the simulation.
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])
    return output, transaction_count


def assert_single_position(test, zipline):
    """
    Drain ``zipline`` and assert that it finishes holding exactly one position.

    Using ``test``'s assertion methods, this checks that:

    * the transaction count equals
      ``test.zipline_test_config['expected_transactions']`` when that key is
      present, otherwise ``test.zipline_test_config['order_count']``;
    * every order reported in any ``'daily_perf'`` update ended FILLED;
    * the final day's results hold exactly one position, and that position
      is in ``test.zipline_test_config['sid']``.

    Returns
    -------
    output, transaction_count
        The values returned by ``drain_zipline``.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected_transactions = config['expected_transactions']
    else:
        expected_transactions = config['order_count']
    test.assertEqual(expected_transactions, transaction_count)

    # The final message is the risk report, and the second to last is the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled. Later updates for the same order
    # id overwrite earlier ones, so only the final state of each order is
    # checked.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update and 'orders' in update['daily_perf']:
            for order in update['daily_perf']['orders']:
                orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "Order %s was not filled." % (order['id'],),
        )

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count


class ExceptionSource(object):
    """
    Iterable source whose every step raises ``ZeroDivisionError``.

    Lets tests check how consumers handle a source that fails.
    """

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberately fail on every iteration step.
        5 / 0

    # Python 2 iterator protocol.
    next = __next__


@contextmanager
def security_list_copy():
    """
    Context manager that points ``security_list`` at a scratch copy.

    Each entry of ``security_list.SECURITY_LISTS_DIR`` is copied into a new
    temporary directory. ``SECURITY_LISTS_DIR`` is then patched to point at
    that directory, and ``security_list.using_copy`` is set to True so that
    ``add_security_data`` can tell it is not writing to the real lists.
    On exit the temporary directory is removed, and removal errors are
    ignored.
    """
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        # NOTE: each entry is copied with copytree, so entries are presumably
        # all directories.
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
    """
    Write ``add`` and ``delete`` symbol files into the (copied) leveraged-ETF
    security list, under ``leveraged_etf_list/20150127/20150125``.

    Parameters
    ----------
    adds : iterable of str
        Symbols written one per line to the ``add`` file.
    deletes : iterable of str
        Symbols written one per line to the ``delete`` file.

    Raises
    ------
    Exception
        If called outside a ``security_list_copy()`` context, which protects
        the real security lists from being modified.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')

    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)

    with open(os.path.join(directory, "delete"), 'w') as f:
        f.writelines(sym + '\n' for sym in deletes)

    with open(os.path.join(directory, "add"), 'w') as f:
        f.writelines(sym + '\n' for sym in adds)


def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : iterator
        Lazy iterator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return filter(lambda pair: pred(*pair), product(values, repeat=2))


def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    return all_pairs_matching_predicate(
        values,
        operator.le if include_diagonal else operator.lt,
    )


def all_subindices(index):
    """
    Return a generator over all valid sub-indices of a pandas Index.

    Yields ``index[start:stop]`` for every ``0 <= start < stop <= len(index)``.
    """
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(range(len(index) + 1))
    )


def make_rotating_asset_info(num_assets,
                             first_start,
                             frequency,
                             periods_between_starts,
                             asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.DateOffset (e.g. trading_day)
        Frequency used to interpret next two arguments. Strings are
        converted with ``pandas.tseries.frequencies.to_offset``.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime`.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets. Symbols are single
        characters starting at 'A'.
    """
    from pandas.tseries.frequencies import to_offset

    # Offsets support integer multiplication and Timestamp addition. Strings
    # do not (3 * 'D' == 'DDD'), so normalize first. Offsets pass through
    # unchanged.
    frequency = to_offset(frequency)
    start_step = periods_between_starts * frequency

    return pd.DataFrame(
        {
            'sid': range(num_assets),
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            'asset_type': ['equity'] * num_assets,
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=start_step,
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * frequency),
                freq=start_step,
                periods=num_assets,
            ),
            'exchange': 'TEST',
        }
    )


def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    assets : array-like
        The sids to create.
    start_date : pd.Timestamp
        Start date shared by every asset.
    end_date : pd.Timestamp
        End date shared by every asset.
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from upper-case letters.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(assets)
    if symbols is None:
        symbols = [chr(ord('A') + i) for i in range(num_assets)]
    return pd.DataFrame(
        {
            'sid': assets,
            'symbol': symbols,
            'asset_type': ['equity'] * num_assets,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        }
    )


def check_arrays(left, right, err_msg='', verbose=True):
    """
    Wrapper around ``numpy.testing.assert_array_equal`` that first verifies
    that ``left`` and ``right`` have exactly the same type.

    Parameters
    ----------
    left, right : array-like
        Values to compare.
    err_msg : str, optional
        Message forwarded to ``assert_array_equal``.
    verbose : bool, optional
        Forwarded to ``assert_array_equal``. If True, the conflicting values
        are appended to the error message.

    Raises
    ------
    AssertionError
        If the types differ or the values are not equal.

    See Also
    --------
    numpy.testing.assert_array_equal
    """
    # Exact type equality is intended here (not isinstance), so that, for
    # example, an ndarray subclass is not treated as equal to a plain ndarray.
    if type(left) != type(right):
        raise AssertionError("%s != %s" % (type(left), type(right)))
    return assert_array_equal(left, right, err_msg=err_msg, verbose=verbose)


class UnexpectedAttributeAccess(Exception):
    """Raised by ``ExplodingObject`` when any attribute is accessed."""
    pass


class ExplodingObject(object):
    """
    Object that will raise an exception on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """
    def __getattribute__(self, name):
        raise UnexpectedAttributeAccess(name)