Merge Python 3 compatibility branch.

This branch should make it so that the code base can be run
using both Python 2.7 and Python 3.3
This commit is contained in:
Eddie Hebert
2014-01-07 15:34:04 -05:00
35 changed files with 193 additions and 142 deletions
+1
View File
@@ -1,6 +1,7 @@
language: python
python:
- "2.7"
- "3.3"
before_install:
- wget -O ta-lib-0.4.0-src.tar.gz http://sourceforge.net/projects/ta-lib/files/ta-lib/0.4.0/ta-lib-0.4.0-src.tar.gz/download
- tar xvzf ta-lib-0.4.0-src.tar.gz
+1 -1
View File
@@ -329,7 +329,7 @@ class SlippageTestCase(TestCase):
slippage_model = VolumeShareSlippage()
try:
_, txn = slippage_model.simulate(event, [order]).next()
_, txn = next(slippage_model.simulate(event, [order]))
except StopIteration:
txn = None
+8 -5
View File
@@ -22,6 +22,8 @@ import pytz
import xlrd
import requests
from six.moves import map
def col_letter_to_index(col_letter):
# Only supports single letter,
@@ -51,12 +53,12 @@ LATEST_ANSWER_KEY_URL = ANSWER_KEY_DL_TEMPLATE.format(
def answer_key_signature():
with open(ANSWER_KEY_PATH, 'r') as f:
with open(ANSWER_KEY_PATH, 'rb') as f:
md5 = hashlib.md5()
while True:
buf = f.read(1024)
md5.update(buf)
while buf != b"":
buf = f.read(1024)
if not buf:
break
md5.update(buf)
return md5.hexdigest()
@@ -288,7 +290,8 @@ class AnswerKey(object):
def get_values(self, data_index):
value_parser = self.value_type_to_value_func[data_index.value_type]
return map(value_parser, self.get_raw_values(data_index))
return [value for value in
map(value_parser, self.get_raw_values(data_index))]
ANSWER_KEY = AnswerKey()
+1 -1
View File
@@ -23,7 +23,7 @@ from zipline.utils import factory
from zipline.finance.trading import SimulationParameters
import answer_key
from . import answer_key
ANSWER_KEY = answer_key.ANSWER_KEY
+3 -3
View File
@@ -293,7 +293,7 @@ class TestEventsThroughRisk(unittest.TestCase):
crm = algo.perf_tracker.cumulative_risk_metrics
first_msg = gen.next()
first_msg = next(gen)
self.assertIsNotNone(first_msg,
"There should be a message emitted.")
@@ -310,7 +310,7 @@ class TestEventsThroughRisk(unittest.TestCase):
crm.metrics.algorithm_volatility[algo.datetime.date()],
"On the first day algorithm volatility does not exist.")
second_msg = gen.next()
second_msg = next(gen)
self.assertIsNotNone(second_msg, "There should be a message "
"emitted.")
@@ -325,7 +325,7 @@ class TestEventsThroughRisk(unittest.TestCase):
crm.algorithm_returns[-1],
decimal=6)
third_msg = gen.next()
third_msg = next(gen)
self.assertEqual(1, len(algo.portfolio.positions),
"Number of positions should stay the same.")
+11 -3
View File
@@ -21,15 +21,22 @@
import matplotlib
matplotlib.use('Agg')
from os import path
import os
from os import path
try:
from path import walk
except ImportError:
# Assume Python 3
from os import walk
import fnmatch
import imp
def test_examples():
os.chdir(example_dir())
for fname in all_matching_files('.', '*.py'):
for fname in all_matching_files(example_dir(), '*.py'):
yield check_example, fname
@@ -40,7 +47,8 @@ def all_matching_files(d, pattern):
fls.extend(nfiles)
files = []
path.walk(d, addfiles, files)
for dirpath, dirnames, filenames in walk(d):
addfiles(files, dirpath, filenames)
return files
+5 -16
View File
@@ -57,14 +57,9 @@ class ExceptionTestCase(TestCase):
**self.zipline_test_config
)
with self.assertRaises(ZeroDivisionError) as ctx:
with self.assertRaises(ZeroDivisionError):
output, _ = drain_zipline(self, zipline)
self.assertEqual(
ctx.exception.message,
'integer division or modulo by zero'
)
def test_tranform_exception(self):
exc_tnfm = StatefulTransform(ExceptionTransform)
self.zipline_test_config['transforms'] = [exc_tnfm]
@@ -76,7 +71,7 @@ class ExceptionTestCase(TestCase):
with self.assertRaises(AssertionError) as ctx:
output, _ = drain_zipline(self, zipline)
self.assertEqual(ctx.exception.message,
self.assertEqual(str(ctx.exception),
'An assertion message')
def test_exception_in_handle_data(self):
@@ -96,7 +91,7 @@ class ExceptionTestCase(TestCase):
with self.assertRaises(Exception) as ctx:
output, _ = drain_zipline(self, zipline)
self.assertEqual(ctx.exception.message,
self.assertEqual(str(ctx.exception),
'Algo exception in handle_data')
def test_zerodivision_exception_in_handle_data(self):
@@ -113,12 +108,9 @@ class ExceptionTestCase(TestCase):
**self.zipline_test_config
)
with self.assertRaises(ZeroDivisionError) as ctx:
with self.assertRaises(ZeroDivisionError):
output, _ = drain_zipline(self, zipline)
self.assertEqual(ctx.exception.message,
'integer division or modulo by zero')
def test_set_portfolio(self):
"""
Are we protected against overwriting an algo's portfolio?
@@ -136,8 +128,5 @@ class ExceptionTestCase(TestCase):
**self.zipline_test_config
)
with self.assertRaises(AttributeError) as ctx:
with self.assertRaises(AttributeError):
output, _ = drain_zipline(self, zipline)
self.assertEqual(ctx.exception.message,
"can't set attribute")
+5 -3
View File
@@ -28,6 +28,8 @@ import numpy as np
from nose.tools import timed
from six.moves import range
import zipline.protocol
from zipline.protocol import Event, DATASOURCE_TYPE
@@ -314,7 +316,7 @@ class FinanceTestCase(TestCase):
alternator = 1
order_date = start_date
for i in xrange(order_count):
for i in range(order_count):
blotter.set_date(order_date)
blotter.order(sid, order_amount * alternator ** i, None, None)
@@ -334,7 +336,7 @@ class FinanceTestCase(TestCase):
order_list = oo[sid]
self.assertEqual(order_count, len(order_list))
for i in xrange(order_count):
for i in range(order_count):
order = order_list[i]
self.assertEqual(order.sid, sid)
self.assertEqual(order.amount, order_amount * alternator ** i)
@@ -372,7 +374,7 @@ class FinanceTestCase(TestCase):
self.assertEqual(len(transactions), len(order_list))
total_volume = 0
for i in xrange(len(transactions)):
for i in range(len(transactions)):
txn = transactions[i]
total_volume += txn.amount
if complete_fill:
+13 -11
View File
@@ -24,6 +24,8 @@ import datetime
import pytz
import itertools
from six.moves import range
import zipline.utils.factory as factory
import zipline.finance.performance as perf
from zipline.finance.slippage import Transaction, create_transaction
@@ -47,7 +49,9 @@ tradingday = datetime.timedelta(hours=6, minutes=30)
def create_txn(event, price, amount):
mock_order = Order(None, None, event.sid, id=None)
return create_transaction(event, mock_order, price, amount)
txn = create_transaction(event, mock_order, price, amount)
txn.source_id = 'MockTransactionSource'
return txn
def benchmark_events_in_range(sim_params):
@@ -67,12 +71,10 @@ def calculate_results(host, events):
perf_tracker = perf.PerformanceTracker(host.sim_params)
all_events = heapq.merge(
((event.dt, event) for event in events),
((event.dt, event) for event in host.benchmark_events))
all_events = date_sorted_sources(events, host.benchmark_events)
filtered_events = [(date, filt_event) for (date, filt_event)
in all_events if date <= events[-1].dt]
filtered_events = [(filt_event.dt, filt_event) for filt_event
in all_events if filt_event.dt <= events[-1].dt]
filtered_events.sort(key=lambda x: x[0])
grouped_events = itertools.groupby(filtered_events, lambda x: x[0])
results = []
@@ -431,7 +433,7 @@ class TestDividendPerformance(unittest.TestCase):
pay_date = self.sim_params.first_open
# find pay date that is much later.
for i in xrange(30):
for i in range(30):
pay_date = factory.get_next_trading_dt(pay_date, oneday)
dividend = factory.create_dividend(
1,
@@ -1133,12 +1135,10 @@ class TestPerformanceTracker(unittest.TestCase):
orders = [event for event in
events if event.type == DATASOURCE_TYPE.ORDER]
all_events = (msg[1] for msg in heapq.merge(
((event.dt, event) for event in events),
((event.dt, event) for event in benchmark_events)))
all_events = date_sorted_sources(events, benchmark_events)
filtered_events = [filt_event for filt_event
in all_events if event.dt <= end_dt]
in all_events if filt_event.dt <= end_dt]
filtered_events.sort(key=lambda x: x.dt)
grouped_events = itertools.groupby(filtered_events, lambda x: x.dt)
perf_messages = []
@@ -1170,6 +1170,7 @@ class TestPerformanceTracker(unittest.TestCase):
amount=-25,
dt=event.dt
)
order.source_id = 'MockOrderSource'
yield order
yield event
txn = Transaction(
@@ -1180,6 +1181,7 @@ class TestPerformanceTracker(unittest.TestCase):
commission=0.50,
order_id=order.id
)
txn.source_id = 'MockTransactionSource'
yield txn
else:
yield event
+5 -3
View File
@@ -16,6 +16,8 @@ import pandas as pd
import pytz
from itertools import cycle
from six import integer_types
from unittest import TestCase
import zipline.utils.factory as factory
@@ -29,7 +31,7 @@ class TestDataFrameSource(TestCase):
assert isinstance(source.end, pd.lib.Timestamp)
for expected_dt, expected_price in df.iterrows():
sid0 = source.next()
sid0 = next(source)
assert expected_dt == sid0.dt
assert expected_price[0] == sid0.price
@@ -71,5 +73,5 @@ class TestDataFrameSource(TestCase):
for event in source:
for check_field in check_fields:
self.assertIn(check_field, event)
self.assertTrue(isinstance(event['volume'], (int, long)))
self.assertEqual(stocks_iter.next(), event['sid'])
self.assertTrue(isinstance(event['volume'], (integer_types)))
self.assertEqual(next(stocks_iter), event['sid'])
+4 -2
View File
@@ -20,6 +20,8 @@ import pandas as pd
from datetime import timedelta, datetime
from unittest import TestCase
from six.moves import range
from zipline.utils.test_utils import setup_logger
from zipline.protocol import Event
@@ -64,7 +66,7 @@ class TestEventWindow(TestCase):
self.monday = datetime(2012, 7, 9, 16, tzinfo=pytz.utc)
self.eleven_normal_days = [self.monday + i * timedelta(days=1)
for i in xrange(11)]
for i in range(11)]
# Modify the end of the period slightly to exercise the
# incomplete day logic.
@@ -75,7 +77,7 @@ class TestEventWindow(TestCase):
# Second set of dates to test holiday handling.
self.jul4_monday = datetime(2012, 7, 2, 16, tzinfo=pytz.utc)
self.week_of_jul4 = [self.jul4_monday + i * timedelta(days=1)
for i in xrange(5)]
for i in range(5)]
def test_market_aware_window_normal_week(self):
window = NoopEventWindow(
+5 -3
View File
@@ -20,7 +20,9 @@ import numpy as np
from datetime import datetime
from itertools import groupby, ifilter
from itertools import groupby
from six.moves import filter
from six import iteritems
from operator import attrgetter
from zipline.errors import (
@@ -194,7 +196,7 @@ class TradingAlgorithm(object):
date_sorted = date_sorted_sources(*self.sources)
if source_filter:
date_sorted = ifilter(source_filter, date_sorted)
date_sorted = filter(source_filter, date_sorted)
with_tnfms = sequential_transforms(date_sorted,
*self.transforms)
@@ -305,7 +307,7 @@ class TradingAlgorithm(object):
# Create transforms by wrapping them into StatefulTransforms
self.transforms = []
for namestring, trans_descr in self.registered_transforms.iteritems():
for namestring, trans_descr in iteritems(self.registered_transforms):
sf = StatefulTransform(
trans_descr['class'],
*trans_descr['args'],
+4 -2
View File
@@ -23,6 +23,8 @@ from functools import partial
import requests
import pandas as pd
from six import iteritems
from . loader_utils import (
date_conversion,
source_to_records,
@@ -50,7 +52,7 @@ _BENCHMARK_MAPPING = {
def benchmark_mappings():
return {key: Mapping(*value)
for key, value
in _BENCHMARK_MAPPING.iteritems()}
in iteritems(_BENCHMARK_MAPPING)}
def get_raw_benchmark_data(start_date, end_date, symbol):
@@ -86,7 +88,7 @@ start_date={start_date}, end_date={end_date}, url={url}""".strip().
end_date=end_date,
url=res.url))
return csv.DictReader(res.iter_lines())
return csv.DictReader(res.text.splitlines())
def get_benchmark_data(symbol, start_date=None, end_date=None):
+21 -27
View File
@@ -26,6 +26,8 @@ import pandas as pd
from pandas.io.data import DataReader
import pytz
from six import iteritems
from . import benchmarks
from . benchmarks import get_benchmark_returns
@@ -60,7 +62,7 @@ INDEX_MAPPING = {
}
def get_datafile(name, mode='r'):
def get_data_filepath(name):
"""
Returns a handle to data file.
@@ -70,7 +72,7 @@ def get_datafile(name, mode='r'):
if not os.path.exists(DATA_PATH):
os.makedirs(DATA_PATH)
return open(os.path.join(DATA_PATH, name), mode)
return os.path.join(DATA_PATH, name)
def get_cache_filepath(name):
@@ -100,9 +102,8 @@ def dump_treasury_curves(module='treasuries', filename='treasury_curves.csv'):
curves = pd.DataFrame(tr_data).T
datafile = get_datafile(filename, mode='wb')
curves.to_csv(datafile)
datafile.close()
data_filepath = get_data_filepath(filename)
curves.to_csv(data_filepath)
return curves
@@ -119,10 +120,9 @@ def dump_benchmarks(symbol):
benchmark = (daily_return.date, daily_return.returns)
benchmark_data.append(benchmark)
datafile = get_datafile(get_benchmark_filename(symbol), mode='wb')
data_filepath = get_data_filepath(get_benchmark_filename(symbol))
benchmark_returns = pd.Series(dict(benchmark_data))
benchmark_returns.to_csv(datafile)
datafile.close()
benchmark_returns.to_csv(data_filepath)
def update_benchmarks(symbol, last_date):
@@ -133,9 +133,8 @@ def update_benchmarks(symbol, last_date):
Puts source benchmark into zipline.
"""
datafile = get_datafile(get_benchmark_filename(symbol), mode='rb')
datafile = get_data_filepath(get_benchmark_filename(symbol))
saved_benchmarks = pd.Series.from_csv(datafile)
datafile.close()
try:
start = last_date + timedelta(days=1)
@@ -144,9 +143,8 @@ def update_benchmarks(symbol, last_date):
benchmark = pd.Series({daily_return.date: daily_return.returns})
saved_benchmarks = saved_benchmarks.append(benchmark)
datafile = get_datafile(get_benchmark_filename(symbol), mode='wb')
datafile = get_data_filepath(get_benchmark_filename(symbol))
saved_benchmarks.to_csv(datafile)
datafile.close()
except benchmarks.BenchmarkDataNotFoundError as exc:
logger.warn(exc)
return saved_benchmarks
@@ -157,19 +155,18 @@ def get_benchmark_filename(symbol):
def load_market_data(bm_symbol='^GSPC'):
bm_filepath = get_data_filepath(get_benchmark_filename(bm_symbol))
try:
fp_bm = get_datafile(get_benchmark_filename(bm_symbol), "rb")
except IOError:
saved_benchmarks = pd.Series.from_csv(bm_filepath)
except OSError:
print("""
data files aren't distributed with source.
Fetching data from Yahoo Finance.
""".strip())
dump_benchmarks(bm_symbol)
fp_bm = get_datafile(get_benchmark_filename(bm_symbol), "rb")
saved_benchmarks = pd.Series.from_csv(bm_filepath)
saved_benchmarks = pd.Series.from_csv(fp_bm)
saved_benchmarks = saved_benchmarks.tz_localize('UTC')
fp_bm.close()
most_recent = pd.Timestamp('today', tz='UTC') - trading_day
most_recent_index = trading_days.searchsorted(most_recent)
@@ -205,17 +202,16 @@ Fetching data from Yahoo Finance.
module, filename, source = INDEX_MAPPING.get(
bm_symbol, INDEX_MAPPING['^GSPC'])
tr_filepath = get_data_filepath(filename)
try:
fp_tr = get_datafile(filename, "rb")
except IOError:
saved_curves = pd.DataFrame.from_csv(tr_filepath)
except OSError:
print("""
data files aren't distributed with source.
Fetching data from {0}
""".format(source).strip())
dump_treasury_curves(module, filename)
fp_tr = get_datafile(filename, "rb")
saved_curves = pd.DataFrame.from_csv(fp_tr)
saved_curves = pd.DataFrame.from_csv(tr_filepath)
# Find the offset of the last date for which we have trading data in our
# list of valid trading days
@@ -236,10 +232,8 @@ Fetching data from {0}
# tzinfo=pytz.utc)
tr_curves[tr_dt] = curve.to_dict()
fp_tr.close()
tr_curves = OrderedDict(sorted(
((dt, c) for dt, c in tr_curves.iteritems()),
((dt, c) for dt, c in iteritems(tr_curves)),
key=lambda t: t[0]))
return benchmark_returns, tr_curves
@@ -291,7 +285,7 @@ must specify stocks or indexes"""
data[stock] = stkd
if indexes is not None:
for name, ticker in indexes.iteritems():
for name, ticker in iteritems(indexes):
print(name)
stkd = DataReader(ticker, 'yahoo', start, end).sort_index()
data[name] = stkd
@@ -327,7 +321,7 @@ def load_from_yahoo(indexes=None,
close_key = 'Adj Close'
else:
close_key = 'Close'
df = pd.DataFrame({key: d[close_key] for key, d in data.iteritems()})
df = pd.DataFrame({key: d[close_key] for key, d in iteritems(data)})
df.index = df.index.tz_localize(pytz.utc)
return df
+3 -1
View File
@@ -30,6 +30,8 @@ from collections import namedtuple
from functools import partial
from six import iteritems
def get_utc_from_exchange_time(naive):
local = pytz.timezone('US/Eastern')
@@ -126,7 +128,7 @@ def _row_cb(mapping, row):
return {
target: apply_mapping(mapping, row)
for target, mapping
in mapping.iteritems()
in iteritems(mapping)
}
+4 -2
View File
@@ -21,6 +21,8 @@ import requests
from collections import OrderedDict
import xml.etree.ElementTree as ET
from six import iteritems
from . loader_utils import (
guarded_conversion,
safe_int,
@@ -61,7 +63,7 @@ _CURVE_MAPPINGS = {
def treasury_mappings(mappings):
return {key: Mapping(*value)
for key, value
in mappings.iteritems()}
in iteritems(mappings)}
class iter_to_stream(object):
@@ -96,7 +98,7 @@ def get_treasury_source():
http://data.treasury.gov/feed.svc/DailyTreasuryYieldCurveRateData\
"""
res = requests.get(url, stream=True)
stream = iter_to_stream(res.iter_lines())
stream = iter_to_stream(res.text.splitlines())
elements = ET.iterparse(stream, ('end', 'start-ns', 'end-ns'))
+8 -6
View File
@@ -77,6 +77,8 @@ import numpy as np
import pandas as pd
from collections import OrderedDict, defaultdict
from six import iteritems, itervalues
import zipline.protocol as zp
from . position import positiondict
@@ -164,7 +166,7 @@ class PerformancePeriod(object):
payment has been disbursed.
"""
cash_payments = 0.0
for sid, pos in self.positions.iteritems():
for sid, pos in iteritems(self.positions):
cash_payments += pos.update_dividends(todays_date)
# credit our cash balance with the dividend payments, or
@@ -307,7 +309,7 @@ class PerformancePeriod(object):
else:
transactions = \
[y.to_dict()
for x in self.processed_transactions.itervalues()
for x in itervalues(self.processed_transactions)
for y in x]
rval['transactions'] = transactions
@@ -315,9 +317,9 @@ class PerformancePeriod(object):
if dt:
# only include orders modified as of the given dt.
orders = [x.to_dict()
for x in self.orders_by_modified[dt].itervalues()]
for x in itervalues(self.orders_by_modified[dt])]
else:
orders = [x.to_dict() for x in self.orders_by_id.itervalues()]
orders = [x.to_dict() for x in itervalues(self.orders_by_id)]
rval['orders'] = orders
return rval
@@ -352,7 +354,7 @@ class PerformancePeriod(object):
positions = self._positions_store
for sid, pos in self.positions.iteritems():
for sid, pos in iteritems(self.positions):
if sid not in positions:
positions[sid] = zp.Position(sid)
position = positions[sid]
@@ -364,7 +366,7 @@ class PerformancePeriod(object):
def get_positions_list(self):
positions = []
for sid, pos in self.positions.iteritems():
for sid, pos in iteritems(self.positions):
if pos.amount != 0:
positions.append(pos.to_dict())
return positions
+3 -1
View File
@@ -24,6 +24,8 @@ import zipline.utils.math_utils as zp_math
import pandas as pd
from pandas.tseries.tools import normalize_date
from six import iteritems
from . risk import (
alpha,
check_entry,
@@ -359,7 +361,7 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}"
return {k: None
if check_entry(k, v)
else v for k, v in rval.iteritems()}
else v for k, v in iteritems(rval)}
def __repr__(self):
statements = []
+4 -2
View File
@@ -20,11 +20,13 @@ import math
import numpy as np
import numpy.linalg as la
from six import iteritems
from zipline.finance import trading
import pandas as pd
import risk
from . import risk
from . risk import (
alpha,
check_entry,
@@ -131,7 +133,7 @@ class RiskMetricsPeriod(object):
}
return {k: None if check_entry(k, v) else v
for k, v in rval.iteritems()}
for k, v in iteritems(rval)}
def __repr__(self):
statements = []
+4 -3
View File
@@ -20,6 +20,9 @@ import math
from copy import copy
from functools import partial
from six import with_metaclass
from zipline.protocol import DATASOURCE_TYPE
import zipline.utils.math_utils as zp_math
@@ -152,9 +155,7 @@ def create_transaction(event, order, price, amount):
return transaction
class SlippageModel(object):
__metaclass__ = abc.ABCMeta
class SlippageModel(with_metaclass(abc.ABCMeta)):
@property
def volume_for_bar(self):
+2
View File
@@ -15,6 +15,8 @@
import heapq
from six.moves import reduce
def _decorate_source(source):
for message in source:
+4 -6
View File
@@ -21,23 +21,24 @@ from hashlib import md5
from datetime import datetime
from zipline.protocol import DATASOURCE_TYPE
from six import iteritems, b
def hash_args(*args, **kwargs):
"""Define a unique string for any set of representable args."""
arg_string = '_'.join([str(arg) for arg in args])
kwarg_string = '_'.join([str(key) + '=' + str(value)
for key, value in kwargs.iteritems()])
for key, value in iteritems(kwargs)])
combined = ':'.join([arg_string, kwarg_string])
hasher = md5()
hasher.update(combined)
hasher.update(b(combined))
return hasher.hexdigest()
def assert_datasource_protocol(event):
"""Assert that an event meets the protocol for datasource outputs."""
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
# Done packets have no dt.
@@ -59,17 +60,14 @@ def assert_trade_protocol(event):
def assert_datasource_unframe_protocol(event):
"""Assert that an event is valid output of zp.DATASOURCE_UNFRAME."""
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
def assert_sort_protocol(event):
"""Assert that an event is valid input to zp.FEED_FRAME."""
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
def assert_sort_unframe_protocol(event):
"""Same as above."""
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
+7 -5
View File
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from six import iteritems, iterkeys
from . utils.protocol_utils import Enum
# Datasource type should completely determine the other fields of a
@@ -171,7 +173,7 @@ class BarData(object):
del self._data[name]
def __iter__(self):
for sid, data in self._data.iteritems():
for sid, data in iteritems(self._data):
# Allow contains override to filter out sids.
if sid in self:
if len(data):
@@ -179,25 +181,25 @@ class BarData(object):
def iterkeys(self):
# Allow contains override to filter out sids.
return (sid for sid in self._data.iterkeys() if sid in self)
return (sid for sid in iterkeys(self._data) if sid in self)
def keys(self):
# Allow contains override to filter out sids.
return list(self.iterkeys())
def itervalues(self):
return (value for sid, value in self.iteritems())
return (value for sid, value in iteritems(self))
def values(self):
return list(self.itervalues())
def iteritems(self):
return ((sid, value) for sid, value
in self._data.iteritems()
in iteritems(self._data)
if sid in self)
def items(self):
return list(self.iteritems())
return list(iteritems(self))
def __len__(self):
return len(self.keys())
+6 -3
View File
@@ -3,13 +3,13 @@ from abc import (
abstractproperty
)
from six import with_metaclass
from zipline.protocol import DATASOURCE_TYPE
from zipline.protocol import Event
class DataSource(object):
__metaclass__ = ABCMeta
class DataSource(with_metaclass(ABCMeta)):
@property
def event_type(self):
@@ -62,3 +62,6 @@ class DataSource(object):
def next(self):
return self.mapped_data.next()
def __next__(self):
return next(self.mapped_data)
+13 -7
View File
@@ -19,10 +19,13 @@ A source to be used in testing.
import pytz
from itertools import cycle, ifilter, izip
from itertools import cycle
from six.moves import filter, zip
from datetime import datetime, timedelta
import numpy as np
from six.moves import range
from zipline.protocol import (
Event,
DATASOURCE_TYPE
@@ -68,9 +71,9 @@ def date_gen(start=datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
# during trading hours.
# NB: Being inside of trading hours is currently dependent upon the
# count parameter being less than the number of trading minutes in a day
for i in xrange(count):
for i in range(count):
if repeats:
for j in xrange(repeats):
for j in range(repeats):
yield cur
else:
yield cur
@@ -90,7 +93,7 @@ def mock_prices(count):
Utility to generate a stream of mock prices. By default
cycles through values from 0.0 to 10.0, n times.
"""
return (float(i % 10) + 1.0 for i in xrange(count))
return (float(i % 10) + 1.0 for i in range(count))
def mock_volumes(count):
@@ -98,7 +101,7 @@ def mock_volumes(count):
Utility to generate a set of volumes. By default cycles
through values from 100 to 1000, incrementing by 50.
"""
return ((i * 50) % 900 + 100 for i in xrange(count))
return ((i * 50) % 900 + 100 for i in range(count))
class SpecificEquityTrades(object):
@@ -163,6 +166,9 @@ class SpecificEquityTrades(object):
def next(self):
return self.generator.next()
def __next__(self):
return next(self.generator)
def rewind(self):
self.generator = self.create_fresh_generator()
@@ -204,7 +210,7 @@ class SpecificEquityTrades(object):
sids = cycle(self.sids)
# Combine the iterators into a single iterator of arguments
arg_gen = izip(sids, prices, volumes, dates)
arg_gen = zip(sids, prices, volumes, dates)
# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id=self.get_hash())
@@ -213,7 +219,7 @@ class SpecificEquityTrades(object):
# If we specified a sid filter, filter out elements that don't
# match the filter.
if self.filter:
filtered = ifilter(
filtered = filter(
lambda event: event.sid in self.filter, unfiltered)
# Otherwise just use all events.
+3 -1
View File
@@ -74,6 +74,8 @@ The algorithm must expose methods:
from copy import deepcopy
import numpy as np
from six.moves import range
from zipline.algorithm import TradingAlgorithm
from zipline.finance.slippage import FixedSlippage
@@ -191,7 +193,7 @@ class TooMuchProcessingAlgorithm(TradingAlgorithm):
def handle_data(self, data):
# Unless we're running on some sort of
# supercomputer this will hit timeout.
for i in xrange(1000000000):
for i in range(1000000000):
self.foo = i
+11 -5
View File
@@ -26,6 +26,12 @@ from numbers import Integral
import pandas as pd
from six import (
string_types,
itervalues,
iteritems
)
from zipline.utils.data import RollingPanel
from zipline.protocol import Event
@@ -187,7 +193,7 @@ class BatchTransform(object):
# enter the batch transform's window IFF a sid filter is not
# specified.
if sids is not None:
if isinstance(sids, (basestring, Integral)):
if isinstance(sids, (string_types, Integral)):
self.static_sids = set([sids])
else:
self.static_sids = set(sids)
@@ -195,7 +201,7 @@ class BatchTransform(object):
self.static_sids = None
self.initial_field_names = fields
if isinstance(self.initial_field_names, basestring):
if isinstance(self.initial_field_names, string_types):
self.initial_field_names = [self.initial_field_names]
self.field_names = set()
@@ -230,7 +236,7 @@ class BatchTransform(object):
Point of entry. Process an event frame.
"""
# extract dates
dts = [event.datetime for event in data._data.itervalues()]
dts = [event.datetime for event in itervalues(data._data)]
# we have to provide the event with a dt. This is only for
# checking if the event is outside the window or not so a
# couple of seconds shouldn't matter. We don't add it to
@@ -238,7 +244,7 @@ class BatchTransform(object):
# sid keys.
event = Event()
event.dt = max(dts)
event.data = {k: v.__dict__ for k, v in data._data.iteritems()
event.data = {k: v.__dict__ for k, v in iteritems(data._data)
# Need to check if data has a 'length' to filter
# out sids without trade data available.
# TODO: expose more of 'no trade available'
@@ -419,7 +425,7 @@ class BatchTransform(object):
# extract field names from sids (price, volume etc), make sure
# every sid has the same fields.
sid_keys = []
for sid in event.data.itervalues():
for sid in itervalues(event.data):
keys = set([name for name, value in sid.items()
if isinstance(value,
(int,
+4 -3
View File
@@ -15,23 +15,24 @@
from collections import defaultdict
from six import string_types, with_metaclass
from zipline.transforms.utils import EventWindow, TransformMeta
from zipline.errors import WrongDataForTransform
class MovingAverage(object):
class MovingAverage(with_metaclass(TransformMeta)):
"""
Class that maintains a dictionary from sids to
MovingAverageEventWindows. For each sid, we maintain moving
averages over any number of distinct fields (For example, we can
maintain a sid's average volume as well as its average price.)
"""
__metaclass__ = TransformMeta
def __init__(self, fields='price',
market_aware=True, window_length=None, delta=None):
if isinstance(fields, basestring):
if isinstance(fields, string_types):
fields = [fields]
self.fields = fields
+3 -2
View File
@@ -17,13 +17,14 @@ from zipline.errors import WrongDataForTransform
from zipline.transforms.utils import TransformMeta
from collections import defaultdict, deque
from six import with_metaclass
class Returns(object):
class Returns(with_metaclass(TransformMeta)):
"""
Class that maintains a dictionary from sids to the sid's
closing price N trading days ago.
"""
__metaclass__ = TransformMeta
def __init__(self, window_length):
self.window_length = window_length
+3 -2
View File
@@ -16,19 +16,20 @@
from collections import defaultdict
from math import sqrt
from six import with_metaclass
from zipline.errors import WrongDataForTransform
from zipline.transforms.utils import EventWindow, TransformMeta
import zipline.utils.math_utils as zp_math
class MovingStandardDev(object):
class MovingStandardDev(with_metaclass(TransformMeta)):
"""
Class that maintains a dictionary from sids to
MovingStandardDevWindows. For each sid, we maintain a the
standard deviation of all events falling within the specified
window.
"""
__metaclass__ = TransformMeta
def __init__(self, market_aware=True, window_length=None, delta=None):
+4 -1
View File
@@ -19,6 +19,9 @@ import numpy as np
import pandas as pd
import talib
import copy
from six import iteritems
from zipline.transforms import BatchTransform
@@ -45,7 +48,7 @@ def zipline_wrapper(talib_fn, key_map, data):
for sid in data.minor_axis:
# build talib_data from zipline data
talib_data = dict()
for talib_key, zipline_key in key_map.iteritems():
for talib_key, zipline_key in iteritems(key_map):
# if zipline_key is found, add it to talib_data
if zipline_key in data:
values = data[zipline_key][sid].values
+3 -5
View File
@@ -17,7 +17,6 @@
"""
Generator versions of transforms.
"""
import types
import logbook
@@ -27,6 +26,8 @@ from datetime import datetime
from collections import deque
from abc import ABCMeta, abstractmethod
from six import with_metaclass
from zipline.protocol import DATASOURCE_TYPE
from zipline.gens.utils import assert_sort_unframe_protocol, hash_args
from zipline.finance import trading
@@ -92,8 +93,6 @@ class StatefulTransform(object):
Otherwise only dt, tnfm_id, and tnfm_value are forwarded.
"""
def __init__(self, tnfm_class, *args, **kwargs):
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
"Stateful transform requires a class."
assert hasattr(tnfm_class, 'update'), \
"Stateful transform requires the class to have an update method"
@@ -150,7 +149,7 @@ class StatefulTransform(object):
yield out_message
class EventWindow(object):
class EventWindow(with_metaclass(ABCMeta)):
"""
Abstract base class for transform classes that calculate iterative
metrics on events within a given timedelta. Maintains a list of
@@ -169,7 +168,6 @@ class EventWindow(object):
price.
"""
# Mark this as an abstract base class.
__metaclass__ = ABCMeta
def __init__(self, market_aware=True, window_length=None, delta=None):
+3 -2
View File
@@ -15,15 +15,16 @@
from collections import defaultdict
from six import with_metaclass
from zipline.errors import WrongDataForTransform
from zipline.transforms.utils import EventWindow, TransformMeta
class MovingVWAP(object):
class MovingVWAP(with_metaclass(TransformMeta)):
"""
Class that maintains a dictionary from sids to VWAPEventWindows.
"""
__metaclass__ = TransformMeta
def __init__(self, market_aware=True, delta=None, window_length=None):
+8 -4
View File
@@ -141,7 +141,8 @@ def create_dividend(sid, payment, declared_date, ex_date, pay_date):
'ex_date': ex_date.replace(hour=0, minute=0, second=0, microsecond=0),
'pay_date': pay_date.replace(hour=0, minute=0, second=0,
microsecond=0),
'type': DATASOURCE_TYPE.DIVIDEND
'type': DATASOURCE_TYPE.DIVIDEND,
'source_id': 'MockDividendSource'
})
return div
@@ -152,7 +153,8 @@ def create_split(sid, ratio, date):
'sid': sid,
'ratio': ratio,
'dt': date.replace(hour=0, minute=0, second=0, microsecond=0),
'type': DATASOURCE_TYPE.SPLIT
'type': DATASOURCE_TYPE.SPLIT,
'source_id': 'MockSplitSource'
})
@@ -162,7 +164,8 @@ def create_txn(sid, price, amount, datetime):
'amount': amount,
'dt': datetime,
'price': price,
'type': DATASOURCE_TYPE.TRANSACTION
'type': DATASOURCE_TYPE.TRANSACTION,
'source_id': 'MockTransactionSource'
})
return txn
@@ -172,7 +175,8 @@ def create_commission(sid, value, datetime):
'dt': datetime,
'type': DATASOURCE_TYPE.COMMISSION,
'cost': value,
'sid': sid
'sid': sid,
'source_id': 'MockCommissionSource'
})
return txn
+6 -1
View File
@@ -1,6 +1,8 @@
from logbook import FileHandler
from zipline.finance.blotter import ORDER_STATUS
from six import itervalues
def setup_logger(test, path='test.log'):
test.log_handler = FileHandler(path)
@@ -57,7 +59,7 @@ def assert_single_position(test, zipline):
for order in update['daily_perf']['orders']:
orders_by_id[order['id']] = order
for order in orders_by_id.itervalues():
for order in itervalues(orders_by_id):
test.assertEqual(
order['status'],
ORDER_STATUS.FILLED,
@@ -93,6 +95,9 @@ class ExceptionSource(object):
def next(self):
5 / 0
def __next__(self):
5 / 0
class ExceptionTransform(object):