diff --git a/.travis.yml b/.travis.yml index 442a14db..43cc9de2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,7 @@ language: python python: - "2.7" + - "3.3" before_install: - wget -O ta-lib-0.4.0-src.tar.gz http://sourceforge.net/projects/ta-lib/files/ta-lib/0.4.0/ta-lib-0.4.0-src.tar.gz/download - tar xvzf ta-lib-0.4.0-src.tar.gz diff --git a/tests/finance/test_slippage.py b/tests/finance/test_slippage.py index 7f0871cc..403d8b46 100644 --- a/tests/finance/test_slippage.py +++ b/tests/finance/test_slippage.py @@ -329,7 +329,7 @@ class SlippageTestCase(TestCase): slippage_model = VolumeShareSlippage() try: - _, txn = slippage_model.simulate(event, [order]).next() + _, txn = next(slippage_model.simulate(event, [order])) except StopIteration: txn = None diff --git a/tests/risk/answer_key.py b/tests/risk/answer_key.py index a1eb8ecf..5a9615a1 100644 --- a/tests/risk/answer_key.py +++ b/tests/risk/answer_key.py @@ -22,6 +22,8 @@ import pytz import xlrd import requests +from six.moves import map + def col_letter_to_index(col_letter): # Only supports single letter, @@ -51,12 +53,12 @@ LATEST_ANSWER_KEY_URL = ANSWER_KEY_DL_TEMPLATE.format( def answer_key_signature(): - with open(ANSWER_KEY_PATH, 'r') as f: + with open(ANSWER_KEY_PATH, 'rb') as f: md5 = hashlib.md5() - while True: + buf = f.read(1024) + md5.update(buf) + while buf != b"": buf = f.read(1024) - if not buf: - break md5.update(buf) return md5.hexdigest() @@ -288,7 +290,8 @@ class AnswerKey(object): def get_values(self, data_index): value_parser = self.value_type_to_value_func[data_index.value_type] - return map(value_parser, self.get_raw_values(data_index)) + return [value for value in + map(value_parser, self.get_raw_values(data_index))] ANSWER_KEY = AnswerKey() diff --git a/tests/risk/test_risk_cumulative.py b/tests/risk/test_risk_cumulative.py index 97be2f7b..16fb3a9a 100644 --- a/tests/risk/test_risk_cumulative.py +++ b/tests/risk/test_risk_cumulative.py @@ -23,7 +23,7 @@ from zipline.utils import factory from zipline.finance.trading import SimulationParameters -import answer_key +from . import answer_key ANSWER_KEY = answer_key.ANSWER_KEY diff --git a/tests/test_events_through_risk.py b/tests/test_events_through_risk.py index 6026c4ad..5e141ca0 100644 --- a/tests/test_events_through_risk.py +++ b/tests/test_events_through_risk.py @@ -293,7 +293,7 @@ class TestEventsThroughRisk(unittest.TestCase): crm = algo.perf_tracker.cumulative_risk_metrics - first_msg = gen.next() + first_msg = next(gen) self.assertIsNotNone(first_msg, "There should be a message emitted.") @@ -310,7 +310,7 @@ class TestEventsThroughRisk(unittest.TestCase): crm.metrics.algorithm_volatility[algo.datetime.date()], "On the first day algorithm volatility does not exist.") - second_msg = gen.next() + second_msg = next(gen) self.assertIsNotNone(second_msg, "There should be a message " "emitted.") @@ -325,7 +325,7 @@ class TestEventsThroughRisk(unittest.TestCase): crm.algorithm_returns[-1], decimal=6) - third_msg = gen.next() + third_msg = next(gen) self.assertEqual(1, len(algo.portfolio.positions), "Number of positions should stay the same.") diff --git a/tests/test_examples.py b/tests/test_examples.py index 1bbbd4ab..0ae919a6 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,15 +21,22 @@ import matplotlib matplotlib.use('Agg') -from os import path import os +from os import path + +try: + from path import walk +except ImportError: + # Assume Python 3 + from os import walk + import fnmatch import imp def test_examples(): os.chdir(example_dir()) - for fname in all_matching_files('.', '*.py'): + for fname in all_matching_files(example_dir(), '*.py'): yield check_example, fname @@ -40,7 +47,8 @@ def all_matching_files(d, pattern): fls.extend(nfiles) files = [] - path.walk(d, addfiles, files) + for dirpath, dirnames, filenames in walk(d): + addfiles(files, dirpath, filenames) return files diff --git a/tests/test_exception_handling.py b/tests/test_exception_handling.py index 6d24908f..0002fd8a 100644 --- a/tests/test_exception_handling.py +++ b/tests/test_exception_handling.py @@ -57,14 +57,9 @@ class ExceptionTestCase(TestCase): **self.zipline_test_config ) - with self.assertRaises(ZeroDivisionError) as ctx: + with self.assertRaises(ZeroDivisionError): output, _ = drain_zipline(self, zipline) - self.assertEqual( - ctx.exception.message, - 'integer division or modulo by zero' - ) - def test_tranform_exception(self): exc_tnfm = StatefulTransform(ExceptionTransform) self.zipline_test_config['transforms'] = [exc_tnfm] @@ -76,7 +71,7 @@ class ExceptionTestCase(TestCase): with self.assertRaises(AssertionError) as ctx: output, _ = drain_zipline(self, zipline) - self.assertEqual(ctx.exception.message, + self.assertEqual(str(ctx.exception), 'An assertion message') def test_exception_in_handle_data(self): @@ -96,7 +91,7 @@ class ExceptionTestCase(TestCase): with self.assertRaises(Exception) as ctx: output, _ = drain_zipline(self, zipline) - self.assertEqual(ctx.exception.message, + self.assertEqual(str(ctx.exception), 'Algo exception in handle_data') def test_zerodivision_exception_in_handle_data(self): @@ -113,12 +108,9 @@ class ExceptionTestCase(TestCase): **self.zipline_test_config ) - with self.assertRaises(ZeroDivisionError) as ctx: + with self.assertRaises(ZeroDivisionError): output, _ = drain_zipline(self, zipline) - self.assertEqual(ctx.exception.message, - 'integer division or modulo by zero') - def test_set_portfolio(self): """ Are we protected against overwriting an algo's portfolio? @@ -136,8 +128,5 @@ class ExceptionTestCase(TestCase): **self.zipline_test_config ) - with self.assertRaises(AttributeError) as ctx: + with self.assertRaises(AttributeError): output, _ = drain_zipline(self, zipline) - - self.assertEqual(ctx.exception.message, - "can't set attribute") diff --git a/tests/test_finance.py b/tests/test_finance.py index 689edd77..9a63815b 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -28,6 +28,8 @@ import numpy as np from nose.tools import timed +from six.moves import range + import zipline.protocol from zipline.protocol import Event, DATASOURCE_TYPE @@ -314,7 +316,7 @@ class FinanceTestCase(TestCase): alternator = 1 order_date = start_date - for i in xrange(order_count): + for i in range(order_count): blotter.set_date(order_date) blotter.order(sid, order_amount * alternator ** i, None, None) @@ -334,7 +336,7 @@ class FinanceTestCase(TestCase): order_list = oo[sid] self.assertEqual(order_count, len(order_list)) - for i in xrange(order_count): + for i in range(order_count): order = order_list[i] self.assertEqual(order.sid, sid) self.assertEqual(order.amount, order_amount * alternator ** i) @@ -372,7 +374,7 @@ class FinanceTestCase(TestCase): self.assertEqual(len(transactions), len(order_list)) total_volume = 0 - for i in xrange(len(transactions)): + for i in range(len(transactions)): txn = transactions[i] total_volume += txn.amount if complete_fill: diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 09388820..296dd1ff 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -24,6 +24,8 @@ import datetime import pytz import itertools +from six.moves import range + import zipline.utils.factory as factory import zipline.finance.performance as perf from zipline.finance.slippage import Transaction, create_transaction @@ -47,7 +49,9 @@ tradingday = datetime.timedelta(hours=6, minutes=30) def create_txn(event, price, amount): mock_order = Order(None, None, event.sid, id=None) - return create_transaction(event, mock_order, price, amount) + txn = create_transaction(event, mock_order, price, amount) + txn.source_id = 'MockTransactionSource' + return txn def benchmark_events_in_range(sim_params): @@ -67,12 +71,10 @@ def calculate_results(host, events): perf_tracker = perf.PerformanceTracker(host.sim_params) - all_events = heapq.merge( - ((event.dt, event) for event in events), - ((event.dt, event) for event in host.benchmark_events)) + all_events = date_sorted_sources(events, host.benchmark_events) - filtered_events = [(date, filt_event) for (date, filt_event) - in all_events if date <= events[-1].dt] + filtered_events = [(filt_event.dt, filt_event) for filt_event + in all_events if filt_event.dt <= events[-1].dt] filtered_events.sort(key=lambda x: x[0]) grouped_events = itertools.groupby(filtered_events, lambda x: x[0]) results = [] @@ -431,7 +433,7 @@ class TestDividendPerformance(unittest.TestCase): pay_date = self.sim_params.first_open # find pay date that is much later. - for i in xrange(30): + for i in range(30): pay_date = factory.get_next_trading_dt(pay_date, oneday) dividend = factory.create_dividend( 1, @@ -1133,12 +1135,10 @@ class TestPerformanceTracker(unittest.TestCase): orders = [event for event in events if event.type == DATASOURCE_TYPE.ORDER] - all_events = (msg[1] for msg in heapq.merge( - ((event.dt, event) for event in events), - ((event.dt, event) for event in benchmark_events))) + all_events = date_sorted_sources(events, benchmark_events) filtered_events = [filt_event for filt_event - in all_events if event.dt <= end_dt] + in all_events if filt_event.dt <= end_dt] filtered_events.sort(key=lambda x: x.dt) grouped_events = itertools.groupby(filtered_events, lambda x: x.dt) perf_messages = [] @@ -1170,6 +1170,7 @@ class TestPerformanceTracker(unittest.TestCase): amount=-25, dt=event.dt ) + order.source_id = 'MockOrderSource' yield order yield event txn = Transaction( @@ -1180,6 +1181,7 @@ class TestPerformanceTracker(unittest.TestCase): commission=0.50, order_id=order.id ) + txn.source_id = 'MockTransactionSource' yield txn else: yield event diff --git a/tests/test_sources.py b/tests/test_sources.py index e1096131..1ba2e937 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -16,6 +16,8 @@ import pandas as pd import pytz from itertools import cycle +from six import integer_types + from unittest import TestCase import zipline.utils.factory as factory @@ -29,7 +31,7 @@ class TestDataFrameSource(TestCase): assert isinstance(source.end, pd.lib.Timestamp) for expected_dt, expected_price in df.iterrows(): - sid0 = source.next() + sid0 = next(source) assert expected_dt == sid0.dt assert expected_price[0] == sid0.price @@ -71,5 +73,5 @@ class TestDataFrameSource(TestCase): for event in source: for check_field in check_fields: self.assertIn(check_field, event) - self.assertTrue(isinstance(event['volume'], (int, long))) - self.assertEqual(stocks_iter.next(), event['sid']) + self.assertTrue(isinstance(event['volume'], (integer_types))) + self.assertEqual(next(stocks_iter), event['sid']) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 43df06d8..2da00e19 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -20,6 +20,8 @@ import pandas as pd from datetime import timedelta, datetime from unittest import TestCase +from six.moves import range + from zipline.utils.test_utils import setup_logger from zipline.protocol import Event @@ -64,7 +66,7 @@ class TestEventWindow(TestCase): self.monday = datetime(2012, 7, 9, 16, tzinfo=pytz.utc) self.eleven_normal_days = [self.monday + i * timedelta(days=1) - for i in xrange(11)] + for i in range(11)] # Modify the end of the period slightly to exercise the # incomplete day logic. @@ -75,7 +77,7 @@ class TestEventWindow(TestCase): # Second set of dates to test holiday handling. self.jul4_monday = datetime(2012, 7, 2, 16, tzinfo=pytz.utc) self.week_of_jul4 = [self.jul4_monday + i * timedelta(days=1) - for i in xrange(5)] + for i in range(5)] def test_market_aware_window_normal_week(self): window = NoopEventWindow( diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 9cb42ae5..8d5baf77 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -20,7 +20,9 @@ import numpy as np from datetime import datetime -from itertools import groupby, ifilter +from itertools import groupby +from six.moves import filter +from six import iteritems from operator import attrgetter from zipline.errors import ( @@ -194,7 +196,7 @@ class TradingAlgorithm(object): date_sorted = date_sorted_sources(*self.sources) if source_filter: - date_sorted = ifilter(source_filter, date_sorted) + date_sorted = filter(source_filter, date_sorted) with_tnfms = sequential_transforms(date_sorted, *self.transforms) @@ -305,7 +307,7 @@ class TradingAlgorithm(object): # Create transforms by wrapping them into StatefulTransforms self.transforms = [] - for namestring, trans_descr in self.registered_transforms.iteritems(): + for namestring, trans_descr in iteritems(self.registered_transforms): sf = StatefulTransform( trans_descr['class'], *trans_descr['args'], diff --git a/zipline/data/benchmarks.py b/zipline/data/benchmarks.py index bc938c0c..9f89f47f 100644 --- a/zipline/data/benchmarks.py +++ b/zipline/data/benchmarks.py @@ -23,6 +23,8 @@ from functools import partial import requests import pandas as pd +from six import iteritems + from . loader_utils import ( date_conversion, source_to_records, @@ -50,7 +52,7 @@ _BENCHMARK_MAPPING = { def benchmark_mappings(): return {key: Mapping(*value) for key, value - in _BENCHMARK_MAPPING.iteritems()} + in iteritems(_BENCHMARK_MAPPING)} def get_raw_benchmark_data(start_date, end_date, symbol): @@ -86,7 +88,7 @@ start_date={start_date}, end_date={end_date}, url={url}""".strip(). end_date=end_date, url=res.url)) - return csv.DictReader(res.iter_lines()) + return csv.DictReader(res.text.splitlines()) def get_benchmark_data(symbol, start_date=None, end_date=None): diff --git a/zipline/data/loader.py b/zipline/data/loader.py index 64333bb3..08d724d9 100644 --- a/zipline/data/loader.py +++ b/zipline/data/loader.py @@ -26,6 +26,8 @@ import pandas as pd from pandas.io.data import DataReader import pytz +from six import iteritems + from . import benchmarks from . benchmarks import get_benchmark_returns @@ -60,7 +62,7 @@ INDEX_MAPPING = { } -def get_datafile(name, mode='r'): +def get_data_filepath(name): """ Returns a handle to data file. @@ -70,7 +72,7 @@ def get_datafile(name, mode='r'): if not os.path.exists(DATA_PATH): os.makedirs(DATA_PATH) - return open(os.path.join(DATA_PATH, name), mode) + return os.path.join(DATA_PATH, name) def get_cache_filepath(name): @@ -100,9 +102,8 @@ def dump_treasury_curves(module='treasuries', filename='treasury_curves.csv'): curves = pd.DataFrame(tr_data).T - datafile = get_datafile(filename, mode='wb') - curves.to_csv(datafile) - datafile.close() + data_filepath = get_data_filepath(filename) + curves.to_csv(data_filepath) return curves @@ -119,10 +120,9 @@ def dump_benchmarks(symbol): benchmark = (daily_return.date, daily_return.returns) benchmark_data.append(benchmark) - datafile = get_datafile(get_benchmark_filename(symbol), mode='wb') + data_filepath = get_data_filepath(get_benchmark_filename(symbol)) benchmark_returns = pd.Series(dict(benchmark_data)) - benchmark_returns.to_csv(datafile) - datafile.close() + benchmark_returns.to_csv(data_filepath) def update_benchmarks(symbol, last_date): @@ -133,9 +133,8 @@ def update_benchmarks(symbol, last_date): Puts source benchmark into zipline. """ - datafile = get_datafile(get_benchmark_filename(symbol), mode='rb') + datafile = get_data_filepath(get_benchmark_filename(symbol)) saved_benchmarks = pd.Series.from_csv(datafile) - datafile.close() try: start = last_date + timedelta(days=1) @@ -144,9 +143,8 @@ def update_benchmarks(symbol, last_date): benchmark = pd.Series({daily_return.date: daily_return.returns}) saved_benchmarks = saved_benchmarks.append(benchmark) - datafile = get_datafile(get_benchmark_filename(symbol), mode='wb') + datafile = get_data_filepath(get_benchmark_filename(symbol)) saved_benchmarks.to_csv(datafile) - datafile.close() except benchmarks.BenchmarkDataNotFoundError as exc: logger.warn(exc) return saved_benchmarks @@ -157,19 +155,18 @@ def get_benchmark_filename(symbol): def load_market_data(bm_symbol='^GSPC'): + bm_filepath = get_data_filepath(get_benchmark_filename(bm_symbol)) try: - fp_bm = get_datafile(get_benchmark_filename(bm_symbol), "rb") - except IOError: + saved_benchmarks = pd.Series.from_csv(bm_filepath) + except OSError: print(""" data files aren't distributed with source. Fetching data from Yahoo Finance. """.strip()) dump_benchmarks(bm_symbol) - fp_bm = get_datafile(get_benchmark_filename(bm_symbol), "rb") + saved_benchmarks = pd.Series.from_csv(bm_filepath) - saved_benchmarks = pd.Series.from_csv(fp_bm) saved_benchmarks = saved_benchmarks.tz_localize('UTC') - fp_bm.close() most_recent = pd.Timestamp('today', tz='UTC') - trading_day most_recent_index = trading_days.searchsorted(most_recent) @@ -205,17 +202,16 @@ Fetching data from Yahoo Finance. module, filename, source = INDEX_MAPPING.get( bm_symbol, INDEX_MAPPING['^GSPC']) + tr_filepath = get_data_filepath(filename) try: - fp_tr = get_datafile(filename, "rb") - except IOError: + saved_curves = pd.DataFrame.from_csv(tr_filepath) + except OSError: print(""" data files aren't distributed with source. Fetching data from {0} """.format(source).strip()) dump_treasury_curves(module, filename) - fp_tr = get_datafile(filename, "rb") - - saved_curves = pd.DataFrame.from_csv(fp_tr) + saved_curves = pd.DataFrame.from_csv(tr_filepath) # Find the offset of the last date for which we have trading data in our # list of valid trading days @@ -236,10 +232,8 @@ Fetching data from {0} # tzinfo=pytz.utc) tr_curves[tr_dt] = curve.to_dict() - fp_tr.close() - tr_curves = OrderedDict(sorted( - ((dt, c) for dt, c in tr_curves.iteritems()), + ((dt, c) for dt, c in iteritems(tr_curves)), key=lambda t: t[0])) return benchmark_returns, tr_curves @@ -291,7 +285,7 @@ must specify stocks or indexes""" data[stock] = stkd if indexes is not None: - for name, ticker in indexes.iteritems(): + for name, ticker in iteritems(indexes): print(name) stkd = DataReader(ticker, 'yahoo', start, end).sort_index() data[name] = stkd @@ -327,7 +321,7 @@ def load_from_yahoo(indexes=None, close_key = 'Adj Close' else: close_key = 'Close' - df = pd.DataFrame({key: d[close_key] for key, d in data.iteritems()}) + df = pd.DataFrame({key: d[close_key] for key, d in iteritems(data)}) df.index = df.index.tz_localize(pytz.utc) return df diff --git a/zipline/data/loader_utils.py b/zipline/data/loader_utils.py index 7ff58406..014a95cb 100644 --- a/zipline/data/loader_utils.py +++ b/zipline/data/loader_utils.py @@ -30,6 +30,8 @@ from collections import namedtuple from functools import partial +from six import iteritems + def get_utc_from_exchange_time(naive): local = pytz.timezone('US/Eastern') @@ -126,7 +128,7 @@ def _row_cb(mapping, row): return { target: apply_mapping(mapping, row) for target, mapping - in mapping.iteritems() + in iteritems(mapping) } diff --git a/zipline/data/treasuries.py b/zipline/data/treasuries.py index a2b325e7..d0e96889 100644 --- a/zipline/data/treasuries.py +++ b/zipline/data/treasuries.py @@ -21,6 +21,8 @@ import requests from collections import OrderedDict import xml.etree.ElementTree as ET +from six import iteritems + from . loader_utils import ( guarded_conversion, safe_int, @@ -61,7 +63,7 @@ _CURVE_MAPPINGS = { def treasury_mappings(mappings): return {key: Mapping(*value) for key, value - in mappings.iteritems()} + in iteritems(mappings)} class iter_to_stream(object): @@ -96,7 +98,7 @@ def get_treasury_source(): http://data.treasury.gov/feed.svc/DailyTreasuryYieldCurveRateData\ """ res = requests.get(url, stream=True) - stream = iter_to_stream(res.iter_lines()) + stream = iter_to_stream(res.text.splitlines()) elements = ET.iterparse(stream, ('end', 'start-ns', 'end-ns')) diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 00368a39..0f3113d1 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -77,6 +77,8 @@ import numpy as np import pandas as pd from collections import OrderedDict, defaultdict +from six import iteritems, itervalues + import zipline.protocol as zp from . position import positiondict @@ -164,7 +166,7 @@ class PerformancePeriod(object): payment has been disbursed. """ cash_payments = 0.0 - for sid, pos in self.positions.iteritems(): + for sid, pos in iteritems(self.positions): cash_payments += pos.update_dividends(todays_date) # credit our cash balance with the dividend payments, or @@ -307,7 +309,7 @@ class PerformancePeriod(object): else: transactions = \ [y.to_dict() - for x in self.processed_transactions.itervalues() + for x in itervalues(self.processed_transactions) for y in x] rval['transactions'] = transactions @@ -315,9 +317,9 @@ class PerformancePeriod(object): if dt: # only include orders modified as of the given dt. orders = [x.to_dict() - for x in self.orders_by_modified[dt].itervalues()] + for x in itervalues(self.orders_by_modified[dt])] else: - orders = [x.to_dict() for x in self.orders_by_id.itervalues()] + orders = [x.to_dict() for x in itervalues(self.orders_by_id)] rval['orders'] = orders return rval @@ -352,7 +354,7 @@ class PerformancePeriod(object): positions = self._positions_store - for sid, pos in self.positions.iteritems(): + for sid, pos in iteritems(self.positions): if sid not in positions: positions[sid] = zp.Position(sid) position = positions[sid] @@ -364,7 +366,7 @@ class PerformancePeriod(object): def get_positions_list(self): positions = [] - for sid, pos in self.positions.iteritems(): + for sid, pos in iteritems(self.positions): if pos.amount != 0: positions.append(pos.to_dict()) return positions diff --git a/zipline/finance/risk/cumulative.py b/zipline/finance/risk/cumulative.py index a91841da..47a12bb4 100644 --- a/zipline/finance/risk/cumulative.py +++ b/zipline/finance/risk/cumulative.py @@ -24,6 +24,8 @@ import zipline.utils.math_utils as zp_math import pandas as pd from pandas.tseries.tools import normalize_date +from six import iteritems + from . risk import ( alpha, check_entry, @@ -359,7 +361,7 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}" return {k: None if check_entry(k, v) - else v for k, v in rval.iteritems()} + else v for k, v in iteritems(rval)} def __repr__(self): statements = [] diff --git a/zipline/finance/risk/period.py b/zipline/finance/risk/period.py index 2e7867f5..08882672 100644 --- a/zipline/finance/risk/period.py +++ b/zipline/finance/risk/period.py @@ -20,11 +20,13 @@ import math import numpy as np import numpy.linalg as la +from six import iteritems + from zipline.finance import trading import pandas as pd -import risk +from . import risk from . risk import ( alpha, check_entry, @@ -131,7 +133,7 @@ class RiskMetricsPeriod(object): } return {k: None if check_entry(k, v) else v - for k, v in rval.iteritems()} + for k, v in iteritems(rval)} def __repr__(self): statements = [] diff --git a/zipline/finance/slippage.py b/zipline/finance/slippage.py index 25d396f9..1a7db006 100644 --- a/zipline/finance/slippage.py +++ b/zipline/finance/slippage.py @@ -20,6 +20,9 @@ import math from copy import copy from functools import partial + +from six import with_metaclass + from zipline.protocol import DATASOURCE_TYPE import zipline.utils.math_utils as zp_math @@ -152,9 +155,7 @@ def create_transaction(event, order, price, amount): return transaction -class SlippageModel(object): - - __metaclass__ = abc.ABCMeta +class SlippageModel(with_metaclass(abc.ABCMeta)): @property def volume_for_bar(self): diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index df2f1677..67c3bd66 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -15,6 +15,8 @@ import heapq +from six.moves import reduce + def _decorate_source(source): for message in source: diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 5cfdc01b..9bd8f2c5 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -21,23 +21,24 @@ from hashlib import md5 from datetime import datetime from zipline.protocol import DATASOURCE_TYPE +from six import iteritems, b + def hash_args(*args, **kwargs): """Define a unique string for any set of representable args.""" arg_string = '_'.join([str(arg) for arg in args]) kwarg_string = '_'.join([str(key) + '=' + str(value) - for key, value in kwargs.iteritems()]) + for key, value in iteritems(kwargs)]) combined = ':'.join([arg_string, kwarg_string]) hasher = md5() - hasher.update(combined) + hasher.update(b(combined)) return hasher.hexdigest() def assert_datasource_protocol(event): """Assert that an event meets the protocol for datasource outputs.""" - assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE # Done packets have no dt. @@ -59,17 +60,14 @@ def assert_trade_protocol(event): def assert_datasource_unframe_protocol(event): """Assert that an event is valid output of zp.DATASOURCE_UNFRAME.""" - assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE def assert_sort_protocol(event): """Assert that an event is valid input to zp.FEED_FRAME.""" - assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE def assert_sort_unframe_protocol(event): """Same as above.""" - assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE diff --git a/zipline/protocol.py b/zipline/protocol.py index d2a65f16..44825d17 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems, iterkeys + from . utils.protocol_utils import Enum # Datasource type should completely determine the other fields of a @@ -171,7 +173,7 @@ class BarData(object): del self._data[name] def __iter__(self): - for sid, data in self._data.iteritems(): + for sid, data in iteritems(self._data): # Allow contains override to filter out sids. if sid in self: if len(data): @@ -179,25 +181,25 @@ class BarData(object): def iterkeys(self): # Allow contains override to filter out sids. - return (sid for sid in self._data.iterkeys() if sid in self) + return (sid for sid in iterkeys(self._data) if sid in self) def keys(self): # Allow contains override to filter out sids. return list(self.iterkeys()) def itervalues(self): - return (value for sid, value in self.iteritems()) + return (value for sid, value in iteritems(self)) def values(self): return list(self.itervalues()) def iteritems(self): return ((sid, value) for sid, value - in self._data.iteritems() + in iteritems(self._data) if sid in self) def items(self): - return list(self.iteritems()) + return list(iteritems(self)) def __len__(self): return len(self.keys()) diff --git a/zipline/sources/data_source.py b/zipline/sources/data_source.py index 3b2c21ba..1b04734e 100644 --- a/zipline/sources/data_source.py +++ b/zipline/sources/data_source.py @@ -3,13 +3,13 @@ from abc import ( abstractproperty ) +from six import with_metaclass + from zipline.protocol import DATASOURCE_TYPE from zipline.protocol import Event -class DataSource(object): - - __metaclass__ = ABCMeta +class DataSource(with_metaclass(ABCMeta)): @property def event_type(self): @@ -62,3 +62,6 @@ class DataSource(object): def next(self): return self.mapped_data.next() + + def __next__(self): + return next(self.mapped_data) diff --git a/zipline/sources/test_source.py b/zipline/sources/test_source.py index 099e31a7..a25cec90 100644 --- a/zipline/sources/test_source.py +++ b/zipline/sources/test_source.py @@ -19,10 +19,13 @@ A source to be used in testing. import pytz -from itertools import cycle, ifilter, izip +from itertools import cycle +from six.moves import filter, zip from datetime import datetime, timedelta import numpy as np +from six.moves import range + from zipline.protocol import ( Event, DATASOURCE_TYPE @@ -68,9 +71,9 @@ def date_gen(start=datetime(2006, 6, 6, 12, tzinfo=pytz.utc), # during trading hours. # NB: Being inside of trading hours is currently dependent upon the # count parameter being less than the number of trading minutes in a day - for i in xrange(count): + for i in range(count): if repeats: - for j in xrange(repeats): + for j in range(repeats): yield cur else: yield cur @@ -90,7 +93,7 @@ def mock_prices(count): Utility to generate a stream of mock prices. By default cycles through values from 0.0 to 10.0, n times. """ - return (float(i % 10) + 1.0 for i in xrange(count)) + return (float(i % 10) + 1.0 for i in range(count)) def mock_volumes(count): @@ -98,7 +101,7 @@ def mock_volumes(count): Utility to generate a set of volumes. By default cycles through values from 100 to 1000, incrementing by 50. """ - return ((i * 50) % 900 + 100 for i in xrange(count)) + return ((i * 50) % 900 + 100 for i in range(count)) class SpecificEquityTrades(object): @@ -163,6 +166,9 @@ class SpecificEquityTrades(object): def next(self): return self.generator.next() + def __next__(self): + return next(self.generator) + def rewind(self): self.generator = self.create_fresh_generator() @@ -204,7 +210,7 @@ class SpecificEquityTrades(object): sids = cycle(self.sids) # Combine the iterators into a single iterator of arguments - arg_gen = izip(sids, prices, volumes, dates) + arg_gen = zip(sids, prices, volumes, dates) # Convert argument packages into events. unfiltered = (create_trade(*args, source_id=self.get_hash()) @@ -213,7 +219,7 @@ class SpecificEquityTrades(object): # If we specified a sid filter, filter out elements that don't # match the filter. if self.filter: - filtered = ifilter( + filtered = filter( lambda event: event.sid in self.filter, unfiltered) # Otherwise just use all events. diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index e4a05fdd..ab14f55b 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -74,6 +74,8 @@ The algorithm must expose methods: from copy import deepcopy import numpy as np +from six.moves import range + from zipline.algorithm import TradingAlgorithm from zipline.finance.slippage import FixedSlippage @@ -191,7 +193,7 @@ class TooMuchProcessingAlgorithm(TradingAlgorithm): def handle_data(self, data): # Unless we're running on some sort of # supercomputer this will hit timeout. - for i in xrange(1000000000): + for i in range(1000000000): self.foo = i diff --git a/zipline/transforms/batch_transform.py b/zipline/transforms/batch_transform.py index 41bfa203..02436d83 100644 --- a/zipline/transforms/batch_transform.py +++ b/zipline/transforms/batch_transform.py @@ -26,6 +26,12 @@ from numbers import Integral import pandas as pd +from six import ( + string_types, + itervalues, + iteritems +) + from zipline.utils.data import RollingPanel from zipline.protocol import Event @@ -187,7 +193,7 @@ class BatchTransform(object): # enter the batch transform's window IFF a sid filter is not # specified. if sids is not None: - if isinstance(sids, (basestring, Integral)): + if isinstance(sids, (string_types, Integral)): self.static_sids = set([sids]) else: self.static_sids = set(sids) @@ -195,7 +201,7 @@ class BatchTransform(object): self.static_sids = None self.initial_field_names = fields - if isinstance(self.initial_field_names, basestring): + if isinstance(self.initial_field_names, string_types): self.initial_field_names = [self.initial_field_names] self.field_names = set() @@ -230,7 +236,7 @@ class BatchTransform(object): Point of entry. Process an event frame. """ # extract dates - dts = [event.datetime for event in data._data.itervalues()] + dts = [event.datetime for event in itervalues(data._data)] # we have to provide the event with a dt. This is only for # checking if the event is outside the window or not so a # couple of seconds shouldn't matter. We don't add it to @@ -238,7 +244,7 @@ class BatchTransform(object): # sid keys. event = Event() event.dt = max(dts) - event.data = {k: v.__dict__ for k, v in data._data.iteritems() + event.data = {k: v.__dict__ for k, v in iteritems(data._data) # Need to check if data has a 'length' to filter # out sids without trade data available. # TODO: expose more of 'no trade available' @@ -419,7 +425,7 @@ class BatchTransform(object): # extract field names from sids (price, volume etc), make sure # every sid has the same fields. sid_keys = [] - for sid in event.data.itervalues(): + for sid in itervalues(event.data): keys = set([name for name, value in sid.items() if isinstance(value, (int, diff --git a/zipline/transforms/mavg.py b/zipline/transforms/mavg.py index 6168c23f..dc79bcc1 100644 --- a/zipline/transforms/mavg.py +++ b/zipline/transforms/mavg.py @@ -15,23 +15,24 @@ from collections import defaultdict +from six import string_types, with_metaclass + from zipline.transforms.utils import EventWindow, TransformMeta from zipline.errors import WrongDataForTransform -class MovingAverage(object): +class MovingAverage(with_metaclass(TransformMeta)): """ Class that maintains a dictionary from sids to MovingAverageEventWindows. For each sid, we maintain moving averages over any number of distinct fields (For example, we can maintain a sid's average volume as well as its average price.) """ - __metaclass__ = TransformMeta def __init__(self, fields='price', market_aware=True, window_length=None, delta=None): - if isinstance(fields, basestring): + if isinstance(fields, string_types): fields = [fields] self.fields = fields diff --git a/zipline/transforms/returns.py b/zipline/transforms/returns.py index 15e8664f..be0bb8d9 100644 --- a/zipline/transforms/returns.py +++ b/zipline/transforms/returns.py @@ -17,13 +17,14 @@ from zipline.errors import WrongDataForTransform from zipline.transforms.utils import TransformMeta from collections import defaultdict, deque +from six import with_metaclass -class Returns(object): + +class Returns(with_metaclass(TransformMeta)): """ Class that maintains a dictionary from sids to the sid's closing price N trading days ago. """ - __metaclass__ = TransformMeta def __init__(self, window_length): self.window_length = window_length diff --git a/zipline/transforms/stddev.py b/zipline/transforms/stddev.py index dddfa745..280f98e4 100644 --- a/zipline/transforms/stddev.py +++ b/zipline/transforms/stddev.py @@ -16,19 +16,20 @@ from collections import defaultdict from math import sqrt +from six import with_metaclass + from zipline.errors import WrongDataForTransform from zipline.transforms.utils import EventWindow, TransformMeta import zipline.utils.math_utils as zp_math -class MovingStandardDev(object): +class MovingStandardDev(with_metaclass(TransformMeta)): """ Class that maintains a dictionary from sids to MovingStandardDevWindows. For each sid, we maintain a the standard deviation of all events falling within the specified window. """ - __metaclass__ = TransformMeta def __init__(self, market_aware=True, window_length=None, delta=None): diff --git a/zipline/transforms/ta.py b/zipline/transforms/ta.py index 54752fb8..aa1c2deb 100644 --- a/zipline/transforms/ta.py +++ b/zipline/transforms/ta.py @@ -19,6 +19,9 @@ import numpy as np import pandas as pd import talib import copy + +from six import iteritems + from zipline.transforms import BatchTransform @@ -45,7 +48,7 @@ def zipline_wrapper(talib_fn, key_map, data): for sid in data.minor_axis: # build talib_data from zipline data talib_data = dict() - for talib_key, zipline_key in key_map.iteritems(): + for talib_key, zipline_key in iteritems(key_map): # if zipline_key is found, add it to talib_data if zipline_key in data: values = data[zipline_key][sid].values diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 8f6d6801..70aed99e 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -17,7 +17,6 @@ """ Generator versions of transforms. """ -import types import logbook @@ -27,6 +26,8 @@ from datetime import datetime from collections import deque from abc import ABCMeta, abstractmethod +from six import with_metaclass + from zipline.protocol import DATASOURCE_TYPE from zipline.gens.utils import assert_sort_unframe_protocol, hash_args from zipline.finance import trading @@ -92,8 +93,6 @@ class StatefulTransform(object): Otherwise only dt, tnfm_id, and tnfm_value are forwarded. """ def __init__(self, tnfm_class, *args, **kwargs): - assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \ - "Stateful transform requires a class." assert hasattr(tnfm_class, 'update'), \ "Stateful transform requires the class to have an update method" @@ -150,7 +149,7 @@ class StatefulTransform(object): yield out_message -class EventWindow(object): +class EventWindow(with_metaclass(ABCMeta)): """ Abstract base class for transform classes that calculate iterative metrics on events within a given timedelta. Maintains a list of @@ -169,7 +168,6 @@ class EventWindow(object): price. """ # Mark this as an abstract base class. - __metaclass__ = ABCMeta def __init__(self, market_aware=True, window_length=None, delta=None): diff --git a/zipline/transforms/vwap.py b/zipline/transforms/vwap.py index 993896c3..48c0b3ba 100644 --- a/zipline/transforms/vwap.py +++ b/zipline/transforms/vwap.py @@ -15,15 +15,16 @@ from collections import defaultdict +from six import with_metaclass + from zipline.errors import WrongDataForTransform from zipline.transforms.utils import EventWindow, TransformMeta -class MovingVWAP(object): +class MovingVWAP(with_metaclass(TransformMeta)): """ Class that maintains a dictionary from sids to VWAPEventWindows. """ - __metaclass__ = TransformMeta def __init__(self, market_aware=True, delta=None, window_length=None): diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index 16b371fe..db0ae182 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -141,7 +141,8 @@ def create_dividend(sid, payment, declared_date, ex_date, pay_date): 'ex_date': ex_date.replace(hour=0, minute=0, second=0, microsecond=0), 'pay_date': pay_date.replace(hour=0, minute=0, second=0, microsecond=0), - 'type': DATASOURCE_TYPE.DIVIDEND + 'type': DATASOURCE_TYPE.DIVIDEND, + 'source_id': 'MockDividendSource' }) return div @@ -152,7 +153,8 @@ def create_split(sid, ratio, date): 'sid': sid, 'ratio': ratio, 'dt': date.replace(hour=0, minute=0, second=0, microsecond=0), - 'type': DATASOURCE_TYPE.SPLIT + 'type': DATASOURCE_TYPE.SPLIT, + 'source_id': 'MockSplitSource' }) @@ -162,7 +164,8 @@ def create_txn(sid, price, amount, datetime): 'amount': amount, 'dt': datetime, 'price': price, - 'type': DATASOURCE_TYPE.TRANSACTION + 'type': DATASOURCE_TYPE.TRANSACTION, + 'source_id': 'MockTransactionSource' }) return txn @@ -172,7 +175,8 @@ def create_commission(sid, value, datetime): 'dt': datetime, 'type': DATASOURCE_TYPE.COMMISSION, 'cost': value, - 'sid': sid + 'sid': sid, + 'source_id': 'MockCommissionSource' }) return txn diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index a4c679f8..9f348c3b 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -1,6 +1,8 @@ from logbook import FileHandler from zipline.finance.blotter import ORDER_STATUS +from six import itervalues + def setup_logger(test, path='test.log'): test.log_handler = FileHandler(path) @@ -57,7 +59,7 @@ def assert_single_position(test, zipline): for order in update['daily_perf']['orders']: orders_by_id[order['id']] = order - for order in orders_by_id.itervalues(): + for order in itervalues(orders_by_id): test.assertEqual( order['status'], ORDER_STATUS.FILLED, @@ -93,6 +95,9 @@ class ExceptionSource(object): def next(self): 5 / 0 + def __next__(self): + 5 / 0 + class ExceptionTransform(object):