Merge pull request #1115 from quantopian/lazy_windows

Fix windows/conda builds
This commit is contained in:
Richard Frank
2016-04-12 19:41:44 -04:00
23 changed files with 476 additions and 240 deletions
+6 -6
View File
@@ -29,9 +29,9 @@ source:
requirements:
build:
- python
- setuptools
- setuptools >18.0
- numpy x.x
- setuptools_scm
- setuptools_scm >1.5.4
- cython ==0.22.1
run:
@@ -42,7 +42,7 @@ test:
# Python imports
imports:
- bcolz
# - bcolz.tests
- bcolz.tests
# commands:
# You can put test commands to be run here. Use this to test that the
@@ -52,9 +52,9 @@ test:
# You can also put a file called run_test.py in the recipe that will be run
# at test time.
# requires:
# - mock
# - unittest2 ; python_version < 2.7
requires:
- mock
- unittest2 # [py26]
# Put any additional test requirements here. For example
# - nose
-3
View File
@@ -33,9 +33,6 @@ cyordereddict==0.2.2
# faster array ops.
bottleneck==1.0.0
# lru_cache
functools32==3.2.3.post2;python_version<'3.0'
contextlib2==0.4.0
# networkx requires decorator
+3 -2
View File
@@ -140,8 +140,9 @@ def _filter_requirements(lines_iter, filter_names=None,
yield line
# We don't currently have any known upper bounds.
REQ_UPPER_BOUNDS = {}
REQ_UPPER_BOUNDS = {
'bcolz': '<1'
}
def _with_bounds(req):
+20 -12
View File
@@ -1081,10 +1081,10 @@ class TestBeforeTradingStart(TestCase):
cls.trading_days[-1], 2, 50
)
cls.minute_reader = BcolzMinuteBarReader(cls.tempdir.path)
cls.adj_reader = cls.create_adjustments_reader()
minute_reader = BcolzMinuteBarReader(cls.tempdir.path)
adj_reader = cls.create_adjustments_reader()
cls.daily_path = cls.tempdir.getpath("testdaily.bcolz")
daily_path = cls.tempdir.getpath("testdaily.bcolz")
dfs = {
1: create_daily_df_for_asset(cls.env, cls.trading_days[0],
cls.trading_days[-1]),
@@ -1094,7 +1094,7 @@ class TestBeforeTradingStart(TestCase):
cls.trading_days[-1])
}
daily_writer = DailyBarWriterFromDataFrames(dfs)
daily_writer.write(cls.daily_path, cls.trading_days, dfs)
daily_writer.write(daily_path, cls.trading_days, dfs)
cls.sim_params = SimulationParameters(
period_start=cls.trading_days[1],
@@ -1105,9 +1105,9 @@ class TestBeforeTradingStart(TestCase):
cls.data_portal = DataPortal(
env=cls.env,
equity_daily_reader=BcolzDailyBarReader(cls.daily_path),
equity_minute_reader=cls.minute_reader,
adjustment_reader=cls.adj_reader
equity_daily_reader=BcolzDailyBarReader(daily_path),
equity_minute_reader=minute_reader,
adjustment_reader=adj_reader
)
@classmethod
@@ -1131,15 +1131,15 @@ class TestBeforeTradingStart(TestCase):
# Mergers and Dividends are not tested, but we need to have these
# anyway
mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid'])
mergers.effective_date = mergers.effective_date.astype(int)
mergers.ratio = mergers.ratio.astype(float)
mergers.sid = mergers.sid.astype(int)
mergers.effective_date = mergers.effective_date.astype(np.int64)
mergers.ratio = mergers.ratio.astype(np.float64)
mergers.sid = mergers.sid.astype(np.int64)
dividends = pd.DataFrame({}, columns=['ex_date', 'record_date',
'declared_date', 'pay_date',
'amount', 'sid'])
dividends.amount = dividends.amount.astype(float)
dividends.sid = dividends.sid.astype(int)
dividends.amount = dividends.amount.astype(np.float64)
dividends.sid = dividends.sid.astype(np.int64)
adj_writer.write(splits, mergers, dividends)
@@ -1147,6 +1147,8 @@ class TestBeforeTradingStart(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.env
cls.tempdir.cleanup()
def test_data_in_bts_minute(self):
@@ -1478,6 +1480,7 @@ class TestAlgoScript(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.env
cls.tempdir.cleanup()
teardown_logger(cls)
@@ -1867,6 +1870,7 @@ class TestGetDatetime(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.env
teardown_logger(cls)
cls.tempdir.cleanup()
@@ -1946,6 +1950,7 @@ class TestTradingControls(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.env
cls.tempdir.cleanup()
@@ -2362,6 +2367,7 @@ class TestAccountControls(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.env
cls.tempdir.cleanup()
@@ -2529,6 +2535,7 @@ class TestFutureFlip(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
cls.tempdir.cleanup()
@skip
@@ -2648,6 +2655,7 @@ class TestOrderCancelation(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
cls.tempdir.cleanup()
@classmethod
+13 -9
View File
@@ -164,6 +164,11 @@ class TestAPIShim(TestCase):
env=cls.env
)
@classmethod
def tearDownClass(cls):
del cls.adj_reader
cls.tempdir.cleanup()
@classmethod
def build_daily_data(cls):
path = cls.tempdir.getpath("testdaily.bcolz")
@@ -203,24 +208,20 @@ class TestAPIShim(TestCase):
# Mergers and Dividends are not tested, but we need to have these
# anyway
mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid'])
mergers.effective_date = mergers.effective_date.astype(int)
mergers.ratio = mergers.ratio.astype(float)
mergers.sid = mergers.sid.astype(int)
mergers.effective_date = mergers.effective_date.astype(np.int64)
mergers.ratio = mergers.ratio.astype(np.float64)
mergers.sid = mergers.sid.astype(np.int64)
dividends = pd.DataFrame({}, columns=['ex_date', 'record_date',
'declared_date', 'pay_date',
'amount', 'sid'])
dividends.amount = dividends.amount.astype(float)
dividends.sid = dividends.sid.astype(int)
dividends.amount = dividends.amount.astype(np.float64)
dividends.sid = dividends.sid.astype(np.int64)
adj_writer.write(splits, mergers, dividends)
return SQLiteAdjustmentReader(path)
@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()
def setUp(self):
self.data_portal = DataPortal(
self.env,
@@ -229,6 +230,9 @@ class TestAPIShim(TestCase):
adjustment_reader=self.adj_reader
)
def tearDown(self):
del self.data_portal
@classmethod
def create_algo(cls, code, filename=None, sim_params=None):
if sim_params is None:
+4
View File
@@ -131,6 +131,8 @@ class TestMinuteBarData(TestBarDataBase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.adjustments_reader
cls.tempdir.cleanup()
@classmethod
@@ -564,6 +566,8 @@ class TestDailyBarData(TestBarDataBase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.adjustments_reader
cls.tempdir.cleanup()
@classmethod
+1
View File
@@ -100,6 +100,7 @@ class TestBenchmark(TestCase):
@classmethod
def tearDownClass(cls):
del cls.data_portal
del cls.env
cls.tempdir.cleanup()
+1 -1
View File
@@ -286,7 +286,7 @@ class FinanceTestCase(TestCase):
else:
alternator = 1
tracker = PerformanceTracker(sim_params, self.env, data_portal)
tracker = PerformanceTracker(sim_params, self.env)
# replicate what tradesim does by going through every minute or day
# of the simulation and processing open orders each time
+20 -19
View File
@@ -108,11 +108,18 @@ class HistoryTestCaseBase(TestCase):
cls.create_data()
@classmethod
def tearDownClass(cls):
del cls.adj_reader
cls.tempdir.cleanup()
def setUp(self):
self.create_data_portal()
@classmethod
def create_data_portal(cls):
def tearDown(self):
del self.data_portal
def create_data_portal(self):
raise NotImplementedError()
@classmethod
@@ -169,10 +176,6 @@ class HistoryTestCaseBase(TestCase):
}
})
@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()
@classmethod
def create_adjustments_reader(cls):
path = cls.tempdir.getpath("test_adjustments.db")
@@ -470,12 +473,11 @@ MINUTE_FIELD_INFO = {
class MinuteEquityHistoryTestCase(HistoryTestCaseBase):
@classmethod
def create_data_portal(cls):
cls.data_portal = DataPortal(
cls.env,
equity_minute_reader=BcolzMinuteBarReader(cls.tempdir.path),
adjustment_reader=cls.adj_reader
def create_data_portal(self):
self.data_portal = DataPortal(
self.env,
equity_minute_reader=BcolzMinuteBarReader(self.tempdir.path),
adjustment_reader=self.adj_reader
)
@classmethod
@@ -1016,15 +1018,14 @@ class MinuteEquityHistoryTestCase(HistoryTestCaseBase):
class DailyEquityHistoryTestCase(HistoryTestCaseBase):
@classmethod
def create_data_portal(cls):
daily_path = cls.tempdir.getpath("testdaily.bcolz")
def create_data_portal(self):
daily_path = self.tempdir.getpath("testdaily.bcolz")
cls.data_portal = DataPortal(
cls.env,
self.data_portal = DataPortal(
self.env,
equity_daily_reader=BcolzDailyBarReader(daily_path),
equity_minute_reader=BcolzMinuteBarReader(cls.tempdir.path),
adjustment_reader=cls.adj_reader
equity_minute_reader=BcolzMinuteBarReader(self.tempdir.path),
adjustment_reader=self.adj_reader
)
@classmethod
+62
View File
@@ -1,6 +1,8 @@
"""
Tests for zipline.utils.memoize.
"""
from collections import defaultdict
import gc
from unittest import TestCase
from zipline.utils.memoize import remember_last
@@ -32,3 +34,63 @@ class TestRememberLast(TestCase):
# Calling the old value should still increment the counter.
self.assertEqual((func(1), call_count[0]), (1, 3))
self.assertEqual((func(1), call_count[0]), (1, 3))
def test_remember_last_method(self):
call_count = defaultdict(int)
class clz(object):
@remember_last
def func(self, x):
call_count[(self, x)] += 1
return x
inst1 = clz()
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1}))
# Calling again with the same argument should just re-use the old
# value, which means func shouldn't get called again.
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1}))
# Calling with a new value should increment the counter.
self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1,
(inst1, 2): 1}))
self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1,
(inst1, 2): 1}))
# Calling the old value should still increment the counter.
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2,
(inst1, 2): 1}))
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2,
(inst1, 2): 1}))
inst2 = clz()
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1}))
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1}))
self.assertEqual((inst2.func(2), call_count),
(2, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1, (inst2, 2): 1}))
self.assertEqual((inst2.func(2), call_count),
(2, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1, (inst2, 2): 1}))
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 2, (inst2, 2): 1}))
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 2, (inst2, 2): 1}))
# Remove the above references to the instances and ensure that
# remember_last has not made its own.
del inst1, inst2
call_count.clear()
while gc.collect():
pass
self.assertFalse([inst for inst in gc.get_objects()
if type(inst) == clz])
+49 -81
View File
@@ -50,7 +50,7 @@ from zipline.utils.serialization_utils import (
loads_with_persistent_ids, dumps_with_persistent_ids
)
from zipline.testing.core import create_data_portal_from_trade_history, \
create_empty_splits_mergers_frame, FakeDataPortal
create_empty_splits_mergers_frame
logger = logging.getLogger('Test Perf Tracking')
@@ -167,7 +167,7 @@ def calculate_results(sim_params,
splits = splits or {}
commissions = commissions or {}
perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal)
perf_tracker = perf.PerformanceTracker(sim_params, env)
results = []
@@ -189,8 +189,10 @@ def calculate_results(sim_params,
except KeyError:
pass
msg = perf_tracker.handle_market_close_daily(date)
perf_tracker.position_tracker.sync_last_sale_prices(date, False)
msg = perf_tracker.handle_market_close_daily(date, data_portal)
perf_tracker.position_tracker.sync_last_sale_prices(
date, False, data_portal,
)
msg['account'] = perf_tracker.get_account(True)
results.append(copy.deepcopy(msg))
return results
@@ -265,9 +267,7 @@ class TestSplitPerformance(unittest.TestCase):
def test_multiple_splits(self):
# if multiple positions all have splits at the same time, verify that
# the total leftover cash is correct
perf_tracker = perf.PerformanceTracker(
self.sim_params, self.env, FakeDataPortal()
)
perf_tracker = perf.PerformanceTracker(self.sim_params, self.env)
asset1 = self.env.asset_finder.retrieve_asset(1)
asset2 = self.env.asset_finder.retrieve_asset(2)
@@ -1240,11 +1240,10 @@ class TestPositionPerformance(unittest.TestCase):
txn1 = create_txn(self.asset1, trades_1[0].dt, 10.0, 100)
txn2 = create_txn(self.asset2, trades_1[0].dt, 10.0, -100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn1)
pp.handle_execution(txn1)
@@ -1252,7 +1251,7 @@ class TestPositionPerformance(unittest.TestCase):
pp.handle_execution(txn2)
dt = trades_1[-2].dt
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
@@ -1280,7 +1279,7 @@ class TestPositionPerformance(unittest.TestCase):
net_liquidation=1000.0)
dt = trades_1[-1].dt
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
@@ -1333,11 +1332,10 @@ class TestPositionPerformance(unittest.TestCase):
self.sim_params,
{1: trades})
txn = create_txn(self.asset1, trades[1].dt, 10.0, 1000)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
@@ -1355,7 +1353,7 @@ class TestPositionPerformance(unittest.TestCase):
shorts_count=0)
# Validate that the account attributes were updated.
pt.sync_last_sale_prices(trades[-2].dt, False)
pt.sync_last_sale_prices(trades[-2].dt, False, data_portal)
# Validate that the account attributes were updated.
account = pp.as_account()
@@ -1373,7 +1371,7 @@ class TestPositionPerformance(unittest.TestCase):
net_liquidation=1000.0)
# now simulate a price jump to $11
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1425,11 +1423,10 @@ class TestPositionPerformance(unittest.TestCase):
self.sim_params,
{1: trades})
txn = create_txn(self.asset1, trades[1].dt, 10.0, 100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal,
period_open=self.sim_params.period_start,
period_close=self.sim_params.period_end)
pp.position_tracker = pt
@@ -1444,7 +1441,7 @@ class TestPositionPerformance(unittest.TestCase):
# stocks with a last sale price of 0.
self.assertEqual(pp.positions[1].last_sale_price, 10.0)
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1544,18 +1541,17 @@ single short-sale transaction"""
{1: trades})
txn = create_txn(self.asset1, trades[1].dt, 10.0, -100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(
1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
pp.handle_execution(txn)
pt.sync_last_sale_prices(trades_1[-1].dt, False)
pt.sync_last_sale_prices(trades_1[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1611,7 +1607,7 @@ single short-sale transaction"""
# simulate a rollover to a new period
pp.rollover()
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1665,17 +1661,16 @@ single short-sale transaction"""
)
# now run a performance period encompassing the entire trade sample.
ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
ptTotal = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
ppTotal.position_tracker = pt
ptTotal.execute_transaction(txn)
ppTotal.handle_execution(txn)
ptTotal.sync_last_sale_prices(trades[-1].dt, False)
ptTotal.sync_last_sale_prices(trades[-1].dt, False, data_portal)
ppTotal.calculate_performance()
@@ -1778,11 +1773,10 @@ cost of sole txn in test"
)
txn = create_txn(self.asset3, trades[1].dt, 10.0, 1)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
@@ -1795,7 +1789,7 @@ cost of sole txn in test"
# stocks with a last sale price of 0.
self.assertEqual(pp.positions[3].last_sale_price, 10.0)
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -1899,17 +1893,16 @@ single short-sale transaction"""
trades_1 = trades[:-2]
txn = create_txn(self.asset3, trades[0].dt, 10.0, -1)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
pp.handle_execution(txn)
pt.sync_last_sale_prices(trades[-3].dt, False)
pt.sync_last_sale_prices(trades[-3].dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -1969,7 +1962,7 @@ single short-sale transaction"""
# simulate a rollover to a new period
pp.rollover()
pt.sync_last_sale_prices(trades_2[-1].dt, False)
pt.sync_last_sale_prices(trades_2[-1].dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -2027,21 +2020,20 @@ single short-sale transaction"""
)
# now run a performance period encompassing the entire trade sample.
ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
ptTotal = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
ppTotal.position_tracker = ptTotal
for trade in trades_1:
ptTotal.sync_last_sale_prices(trade.dt, False)
ptTotal.sync_last_sale_prices(trade.dt, False, data_portal)
ptTotal.execute_transaction(txn)
ppTotal.handle_execution(txn)
for trade in trades_2:
ptTotal.sync_last_sale_prices(trade.dt, False)
ptTotal.sync_last_sale_prices(trade.dt, False, data_portal)
ppTotal.calculate_performance()
@@ -2144,11 +2136,10 @@ trade after cover"""
short_txn = create_txn(self.asset1, trades[1].dt, 10.0, -100)
cover_txn = create_txn(self.asset1, trades[6].dt, 7.0, 100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(short_txn)
@@ -2156,7 +2147,7 @@ trade after cover"""
pt.execute_transaction(cover_txn)
pp.handle_execution(cover_txn)
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -2231,13 +2222,12 @@ shares in position"
self.sim_params,
{1: trades})
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(
1000.0,
self.env.asset_finder,
self.sim_params.data_frequency,
data_portal,
period_open=self.sim_params.period_start,
period_close=self.sim_params.trading_days[-1]
)
@@ -2264,7 +2254,7 @@ shares in position"
"should have a cost basis of 11"
)
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
@@ -2281,7 +2271,7 @@ shares in position"
pp.handle_execution(sale_txn)
dt = down_tick.dt
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -2299,11 +2289,10 @@ shares in position"
self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400")
pt3 = perf.PositionTracker(self.env.asset_finder, data_portal,
pt3 = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp3.position_tracker = pt3
average_cost = 0
@@ -2317,7 +2306,7 @@ shares in position"
pp3.handle_execution(sale_txn)
trades.append(down_tick)
pt3.sync_last_sale_prices(trades[-1].dt, False)
pt3.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp3.calculate_performance()
self.assertEqual(
@@ -2351,18 +2340,11 @@ shares in position"
)
cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]
trades = factory.create_trade_history(*history_args)
transactions = factory.create_txn_history(*history_args)
data_portal = create_data_portal_from_trade_history(
self.env,
self.tempdir,
self.sim_params,
{1: trades})
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, data_portal,
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency)
pp.position_tracker = pt
@@ -2413,22 +2395,8 @@ class TestPositionTracker(unittest.TestCase):
sim_params = factory.create_simulation_parameters(
num_days=4, env=self.env
)
trades = factory.create_trade_history(
1,
[10, 10, 10, 11],
[100, 100, 100, 100],
oneday,
sim_params,
env=self.env
)
data_portal = create_data_portal_from_trade_history(
self.env,
self.tempdir,
sim_params,
{1: trades})
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
sim_params.data_frequency)
pos_stats = pt.stats()
@@ -2450,7 +2418,7 @@ class TestPositionTracker(unittest.TestCase):
self.assertNotIsInstance(val, (bool, np.bool_))
def test_position_values_and_exposures(self):
pt = perf.PositionTracker(self.env.asset_finder, None, None)
pt = perf.PositionTracker(self.env.asset_finder, None)
dt = pd.Timestamp("1984/03/06 3:00PM")
pos1 = perf.Position(1, amount=np.float64(10.0),
last_sale_date=dt, last_sale_price=10)
@@ -2482,7 +2450,7 @@ class TestPositionTracker(unittest.TestCase):
self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure)
def test_update_positions(self):
pt = perf.PositionTracker(self.env.asset_finder, None, None)
pt = perf.PositionTracker(self.env.asset_finder, None)
dt = pd.Timestamp("2014/01/01 3:00PM")
pos1 = perf.Position(1, amount=np.float64(10.0),
last_sale_date=dt, last_sale_price=10)
+5 -3
View File
@@ -1,6 +1,7 @@
from numpy import (
float64,
uint32
uint32,
int64,
)
from bcolz import ctable
@@ -37,8 +38,9 @@ class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
return array.astype(uint32)
elif colname == 'day':
nanos_per_second = (1000 * 1000 * 1000)
self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
return (array.view(int) / nanos_per_second).astype(uint32)
self.check_uint_safe(arrmax.view(int64) / nanos_per_second,
colname)
return (array.view(int64) / nanos_per_second).astype(uint32)
@staticmethod
def check_uint_safe(value, colname):
+11 -9
View File
@@ -492,7 +492,6 @@ class TradingAlgorithm(object):
self.perf_tracker = PerformanceTracker(
sim_params=self.sim_params,
env=self.trading_environment,
data_portal=self.data_portal
)
# Set the dt initially to the period start by forcing it to change.
@@ -603,14 +602,17 @@ class TradingAlgorithm(object):
# Create zipline and loop through simulated_trading.
# Each iteration returns a perf dictionary
perfs = []
for perf in self.get_generator():
perfs.append(perf)
try:
perfs = []
for perf in self.get_generator():
perfs.append(perf)
# convert perf dict to pandas dataframe
daily_stats = self._create_daily_stats(perfs)
# convert perf dict to pandas dataframe
daily_stats = self._create_daily_stats(perfs)
self.analyze(daily_stats)
self.analyze(daily_stats)
finally:
self.data_portal = None
return daily_stats
@@ -1057,7 +1059,7 @@ class TradingAlgorithm(object):
def updated_portfolio(self):
if self.portfolio_needs_update:
self.perf_tracker.position_tracker.sync_last_sale_prices(
self.datetime, self._in_before_trading_start)
self.datetime, self._in_before_trading_start, self.data_portal)
self._portfolio = \
self.perf_tracker.get_portfolio(self.performance_needs_update)
self.portfolio_needs_update = False
@@ -1071,7 +1073,7 @@ class TradingAlgorithm(object):
def updated_account(self):
if self.account_needs_update:
self.perf_tracker.position_tracker.sync_last_sale_prices(
self.datetime, self._in_before_trading_start)
self.datetime, self._in_before_trading_start, self.data_portal)
self._account = \
self.perf_tracker.get_account(self.performance_needs_update)
self.account_needs_update = False
+2 -7
View File
@@ -31,13 +31,12 @@ from zipline.data.us_equity_loader import (
)
from zipline.utils import tradingcalendar
from zipline.utils.compat import lru_cache
from zipline.utils.math_utils import (
nansum,
nanmean,
nanstd
)
from zipline.utils.memoize import remember_last
from zipline.utils.memoize import remember_last, weak_lru_cache
from zipline.errors import (
NoTradeDataAvailableTooEarly,
NoTradeDataAvailableTooLate,
@@ -1091,10 +1090,6 @@ class DataPortal(object):
return data
@remember_last
def _get_market_minutes_for_day(self, end_date):
return self.env.market_minutes_for_day(pd.Timestamp(end_date))
def _get_history_daily_window_equities(
self, assets, days_for_window, end_dt, field_to_use):
ends_at_midnight = end_dt.hour == 0 and end_dt.minute == 0
@@ -1686,7 +1681,7 @@ class DataPortal(object):
else:
return [assets] if isinstance(assets, Asset) else []
@lru_cache(20)
@weak_lru_cache(20)
def _get_minute_count_for_transform(self, ending_minute, days_count):
# cache size picked somewhat loosely. this code exists purely to
# handle deprecated API.
+5 -5
View File
@@ -163,10 +163,10 @@ class BcolzMinuteBarMetadata(object):
'first_trading_day': str(self.first_trading_day.date()),
'market_opens': self.market_opens.values.
astype('datetime64[m]').
astype(int).tolist(),
astype(np.int64).tolist(),
'market_closes': self.market_closes.values.
astype('datetime64[m]').
astype(int).tolist(),
astype(np.int64).tolist(),
'ohlc_ratio': self.ohlc_ratio,
}
with open(self.metadata_path(rootdir), 'w+') as fp:
@@ -603,10 +603,10 @@ class BcolzMinuteBarReader(object):
self._market_opens = metadata.market_opens
self._market_open_values = metadata.market_opens.values.\
astype('datetime64[m]').astype(int)
astype('datetime64[m]').astype(np.int64)
self._market_closes = metadata.market_closes
self._market_close_values = metadata.market_closes.values.\
astype('datetime64[m]').astype(int)
astype('datetime64[m]').astype(np.int64)
self._ohlc_inverse = 1.0 / metadata.ohlc_ratio
@@ -643,7 +643,7 @@ class BcolzMinuteBarReader(object):
"""
market_opens = self._market_opens.values.astype('datetime64[m]')
market_closes = self._market_closes.values.astype('datetime64[m]')
minutes_per_day = (market_closes - market_opens).astype(int)
minutes_per_day = (market_closes - market_opens).astype(np.int64)
early_indices = np.where(
minutes_per_day != US_EQUITIES_MINUTES_PER_DAY - 1)[0]
regular_closes = market_opens[early_indices] + timedelta64(
-3
View File
@@ -133,7 +133,6 @@ class PerformancePeriod(object):
starting_cash,
asset_finder,
data_frequency,
data_portal,
period_open=None,
period_close=None,
keep_transactions=True,
@@ -144,8 +143,6 @@ class PerformancePeriod(object):
self.asset_finder = asset_finder
self.data_frequency = data_frequency
self._data_portal = data_portal
self.period_open = period_open
self.period_close = period_close
@@ -119,14 +119,9 @@ def calc_gross_value(long_value, short_value):
class PositionTracker(object):
def __init__(self, asset_finder, data_portal, data_frequency):
def __init__(self, asset_finder, data_frequency):
self.asset_finder = asset_finder
# FIXME really want to avoid storing a data portal here,
# but the path to get to maybe_create_close_position_transaction
# is long and tortuous
self._data_portal = data_portal
# sid => position object
self.positions = positiondict()
# Arrays for quick calculations of positions value
@@ -316,12 +311,12 @@ class PositionTracker(object):
return net_cash_payment
def maybe_create_close_position_transaction(self, asset, dt):
def maybe_create_close_position_transaction(self, asset, dt, data_portal):
if not self.positions.get(asset):
return None
amount = self.positions.get(asset).amount
price = self._data_portal.get_spot_value(
price = data_portal.get_spot_value(
asset, 'price', dt, self.data_frequency)
# Get the last traded price if price is no longer available
@@ -372,8 +367,8 @@ class PositionTracker(object):
positions.append(pos.to_dict())
return positions
def sync_last_sale_prices(self, dt, handle_non_market_minutes):
data_portal = self._data_portal
def sync_last_sale_prices(self, dt, handle_non_market_minutes,
data_portal):
if not handle_non_market_minutes:
for asset, position in iteritems(self.positions):
last_sale_price = data_portal.get_spot_value(
+20 -25
View File
@@ -78,7 +78,7 @@ class PerformanceTracker(object):
"""
Tracks the performance of the algorithm.
"""
def __init__(self, sim_params, env, data_portal):
def __init__(self, sim_params, env):
self.sim_params = sim_params
self.env = env
@@ -101,15 +101,8 @@ class PerformanceTracker(object):
self.trading_days = all_trading_days[mask]
self._data_portal = data_portal
if data_portal is not None:
self._adjustment_reader = data_portal._adjustment_reader
else:
self._adjustment_reader = None
self.position_tracker = PositionTracker(
asset_finder=env.asset_finder,
data_portal=data_portal,
data_frequency=self.sim_params.data_frequency)
if self.emission_rate == 'daily':
@@ -132,7 +125,6 @@ class PerformanceTracker(object):
# initial cash is your capital base.
starting_cash=self.capital_base,
data_frequency=self.sim_params.data_frequency,
data_portal=data_portal,
# the cumulative period will be calculated over the entire test.
period_open=self.period_start,
period_close=self.period_end,
@@ -152,7 +144,6 @@ class PerformanceTracker(object):
# initial cash is your capital base.
starting_cash=self.capital_base,
data_frequency=self.sim_params.data_frequency,
data_portal=data_portal,
# the daily period will be calculated for the market day
period_open=self.market_open,
period_close=self.market_close,
@@ -264,13 +255,13 @@ class PerformanceTracker(object):
self.cumulative_performance.handle_commission(cost)
self.todays_performance.handle_commission(cost)
def process_close_position(self, asset, dt):
def process_close_position(self, asset, dt, data_portal):
txn = self.position_tracker.\
maybe_create_close_position_transaction(asset, dt)
maybe_create_close_position_transaction(asset, dt, data_portal)
if txn:
self.process_transaction(txn)
def check_upcoming_dividends(self, next_trading_day):
def check_upcoming_dividends(self, next_trading_day, adjustment_reader):
"""
Check if we currently own any stocks with dividends whose ex_date is
the next trading day. Track how much we should be payed on those
@@ -280,7 +271,7 @@ class PerformanceTracker(object):
is the next trading day. Apply all such benefits, then recalculate
performance.
"""
if self._adjustment_reader is None:
if adjustment_reader is None:
return
position_tracker = self.position_tracker
held_sids = set(position_tracker.positions)
@@ -291,10 +282,10 @@ class PerformanceTracker(object):
if held_sids:
asset_finder = self.env.asset_finder
cash_dividends = self._adjustment_reader.\
cash_dividends = adjustment_reader.\
get_dividends_with_ex_date(held_sids, next_trading_day,
asset_finder)
stock_dividends = self._adjustment_reader.\
stock_dividends = adjustment_reader.\
get_stock_dividends_with_ex_date(held_sids, next_trading_day,
asset_finder)
@@ -310,7 +301,7 @@ class PerformanceTracker(object):
self.cumulative_performance.handle_dividends_paid(net_cash_payment)
self.todays_performance.handle_dividends_paid(net_cash_payment)
def handle_minute_close(self, dt):
def handle_minute_close(self, dt, data_portal):
"""
Handles the close of the given minute. This includes handling
market-close functions if the given minute is the end of the market
@@ -327,7 +318,7 @@ class PerformanceTracker(object):
A tuple of the minute perf packet and daily perf packet.
If the market day has not ended, the daily perf packet is None.
"""
self.position_tracker.sync_last_sale_prices(dt, False)
self.position_tracker.sync_last_sale_prices(dt, False, data_portal)
self.update_performance()
todays_date = normalize_date(dt)
account = self.get_account(False)
@@ -346,16 +337,18 @@ class PerformanceTracker(object):
# if this is the close, update dividends for the next day.
# Return the performance tuple
if dt == self.market_close:
return minute_packet, self._handle_market_close(todays_date)
return minute_packet, self._handle_market_close(
todays_date, data_portal._adjustment_reader,
)
else:
return minute_packet, None
def handle_market_close_daily(self, dt):
def handle_market_close_daily(self, dt, data_portal):
"""
Function called after handle_data when running with daily emission
rate.
"""
self.position_tracker.sync_last_sale_prices(dt, False)
self.position_tracker.sync_last_sale_prices(dt, False, data_portal)
self.update_performance()
completed_date = self.day
account = self.get_account(False)
@@ -368,11 +361,12 @@ class PerformanceTracker(object):
benchmark_value,
account.leverage)
daily_packet = self._handle_market_close(completed_date)
daily_packet = self._handle_market_close(
completed_date, data_portal._adjustment_reader,
)
return daily_packet
def _handle_market_close(self, completed_date):
def _handle_market_close(self, completed_date, adjustment_reader):
# increment the day counter before we move markers forward.
self.day_count += 1.0
@@ -406,7 +400,8 @@ class PerformanceTracker(object):
return daily_update
# Check for any dividends, then return the daily perf packet
self.check_upcoming_dividends(next_trading_day=next_trading_day)
self.check_upcoming_dividends(next_trading_day=next_trading_day,
adjustment_reader=adjustment_reader)
return daily_update
def handle_simulation_end(self):
+20 -23
View File
@@ -97,20 +97,9 @@ class AlgorithmSimulator(object):
Main generator work loop.
"""
algo = self.algo
algo.data_portal = self.data_portal
handle_data = algo.event_manager.handle_data
current_data = self.current_data
data_portal = self.data_portal
# can't cache a pointer to algo.perf_tracker because we're not
# guaranteed that the algo doesn't swap out perf trackers during
# its lifetime.
# likewise, we can't cache a pointer to the blotter.
algo.perf_tracker.position_tracker.data_portal = data_portal
def every_bar(dt_to_use):
def every_bar(dt_to_use, current_data=self.current_data,
handle_data=algo.event_manager.handle_data):
# called every tick (minute or day).
self.simulation_dt = dt_to_use
@@ -152,7 +141,8 @@ class AlgorithmSimulator(object):
self.algo.account_needs_update = True
self.algo.performance_needs_update = True
def once_a_day(midnight_dt):
def once_a_day(midnight_dt, current_data=self.current_data,
data_portal=self.data_portal):
# Get the positions before updating the date so that prices are
# fetched for trading close instead of midnight
positions = algo.perf_tracker.position_tracker.positions
@@ -183,11 +173,15 @@ class AlgorithmSimulator(object):
# call before trading start
algo.before_trading_start(current_data)
def handle_benchmark(date):
def handle_benchmark(date, benchmark_source=self.benchmark_source):
algo.perf_tracker.all_benchmark_returns[date] = \
self.benchmark_source.get_value(date)
benchmark_source.get_value(date)
def on_exit():
self.benchmark_source = self.current_data = self.data_portal = None
with ExitStack() as stack:
stack.callback(on_exit)
stack.enter_context(self.processor)
stack.enter_context(ZiplineAPI(self.algo))
@@ -245,8 +239,9 @@ class AlgorithmSimulator(object):
assets_to_clear = \
[asset for asset in position_assets if past_auto_close_date(asset)]
perf_tracker = algo.perf_tracker
data_portal = self.data_portal
for asset in assets_to_clear:
perf_tracker.process_close_position(asset, dt)
perf_tracker.process_close_position(asset, dt, data_portal)
# Remove open orders for any sids that have reached their
# auto_close_date.
@@ -257,23 +252,25 @@ class AlgorithmSimulator(object):
for asset in assets_to_cancel:
blotter.cancel_all_orders_for_asset(asset)
@staticmethod
def _get_daily_message(dt, algo, perf_tracker):
def _get_daily_message(self, dt, algo, perf_tracker):
"""
Get a perf message for the given datetime.
"""
perf_message = perf_tracker.handle_market_close_daily(dt)
perf_message = perf_tracker.handle_market_close_daily(
dt, self.data_portal,
)
perf_message['daily_perf']['recorded_vars'] = algo.recorded_vars
return perf_message
@staticmethod
def _get_minute_message(dt, algo, perf_tracker):
def _get_minute_message(self, dt, algo, perf_tracker):
"""
Get a perf message for the given datetime.
"""
rvars = algo.recorded_vars
minute_message, daily_message = perf_tracker.handle_minute_close(dt)
minute_message, daily_message = perf_tracker.handle_minute_close(
dt, self.data_portal,
)
minute_message['minute_perf']['recorded_vars'] = rvars
if daily_message:
+1 -2
View File
@@ -16,9 +16,8 @@ import pandas as pd
from .utils.enum import enum
from zipline._protocol import BarData as _BarData
from zipline._protocol import BarData # noqa
BarData = _BarData
# Datasource type should completely determine the other fields of a
# message with its type.
+13 -10
View File
@@ -45,7 +45,8 @@ from zipline.utils.tradingcalendar import trading_days
import numpy as np
from numpy import (
float64,
uint32
uint32,
int64,
)
@@ -456,8 +457,9 @@ def make_trade_data_for_asset_info(dates,
sids = asset_info.keys()
date_field = 'day' if frequency == 'daily' else 'dt'
price_sid_deltas = np.arange(len(sids), dtype=float) * price_step_by_sid
price_date_deltas = np.arange(len(dates), dtype=float) * price_step_by_date
price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
price_date_deltas = (np.arange(len(dates), dtype=float64) *
price_step_by_date)
prices = (price_sid_deltas + price_date_deltas[:, None]) + price_start
volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
@@ -723,8 +725,9 @@ class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
return array.astype(uint32)
elif colname == 'day':
nanos_per_second = (1000 * 1000 * 1000)
self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
return (array.view(int) / nanos_per_second).astype(uint32)
self.check_uint_safe(arrmax.view(int64) / nanos_per_second,
colname)
return (array.view(int64) / nanos_per_second).astype(uint32)
@staticmethod
def check_uint_safe(value, colname):
@@ -1198,8 +1201,8 @@ def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
'pay_date': np.array([], dtype='datetime64[ns]'),
'record_date': np.array([], dtype='datetime64[ns]'),
'declared_date': np.array([], dtype='datetime64[ns]'),
'amount': np.array([], dtype=float),
'sid': np.array([], dtype=int),
'amount': np.array([], dtype=float64),
'sid': np.array([], dtype=int64),
}
dividends = pd.DataFrame(
data,
@@ -1360,9 +1363,9 @@ def create_empty_splits_mergers_frame():
return pd.DataFrame(
{
# Hackery to make the dtypes correct on an empty frame.
'effective_date': np.array([], dtype=int),
'ratio': np.array([], dtype=float),
'sid': np.array([], dtype=int),
'effective_date': np.array([], dtype=int64),
'ratio': np.array([], dtype=float64),
'sid': np.array([], dtype=int64),
},
index=pd.DatetimeIndex([]),
columns=['effective_date', 'ratio', 'sid'],
-7
View File
@@ -1,7 +0,0 @@
from six import PY2
if PY2:
from functools32 import lru_cache # noqa
else:
from functools import lru_cache # noqa
+215 -3
View File
@@ -1,8 +1,13 @@
"""
Tools for memoization of function results.
"""
from zipline.utils.compat import lru_cache
from weakref import WeakKeyDictionary
from collections import OrderedDict, Sequence
from functools import wraps
from itertools import compress
from weakref import WeakKeyDictionary, ref
from six.moves._thread import allocate_lock as Lock
from toolz.sandbox import unzip
class lazyval(object):
@@ -84,4 +89,211 @@ class classlazyval(lazyval):
return super(classlazyval, self).__get__(owner, owner)
remember_last = lru_cache(1)
def _weak_lru_cache(maxsize=100):
"""
Users should only access the lru_cache through its public API:
cache_info, cache_clear
The internals of the lru_cache are encapsulated for thread safety and
to allow the implementation to change.
"""
def decorating_function(
user_function, tuple=tuple, sorted=sorted, len=len,
KeyError=KeyError):
hits, misses = [0], [0]
kwd_mark = (object(),) # separates positional and keyword args
lock = Lock() # needed because OrderedDict isn't threadsafe
if maxsize is None:
cache = _WeakArgsDict() # cache without ordering or size limit
@wraps(user_function)
def wrapper(*args, **kwds):
key = args
if kwds:
key += kwd_mark + tuple(sorted(kwds.items()))
try:
result = cache[key]
hits[0] += 1
return result
except KeyError:
pass
result = user_function(*args, **kwds)
cache[key] = result
misses[0] += 1
return result
else:
# ordered least recent to most recent
cache = _WeakArgsOrderedDict()
cache_popitem = cache.popitem
cache_renew = cache.move_to_end
@wraps(user_function)
def wrapper(*args, **kwds):
key = args
if kwds:
key += kwd_mark + tuple(sorted(kwds.items()))
with lock:
try:
result = cache[key]
cache_renew(key) # record recent use of this key
hits[0] += 1
return result
except KeyError:
pass
result = user_function(*args, **kwds)
with lock:
cache[key] = result # record recent use of this key
misses[0] += 1
if len(cache) > maxsize:
# purge least recently used cache entry
cache_popitem(False)
return result
def cache_info():
"""Report cache statistics"""
with lock:
return hits[0], misses[0], maxsize, len(cache)
def cache_clear():
"""Clear the cache and cache statistics"""
with lock:
cache.clear()
hits[0] = misses[0] = 0
wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
return wrapper
return decorating_function
class _WeakArgs(Sequence):
"""
Works with _WeakArgsDict to provide a weak cache for function args.
When any of those args are gc'd, the pair is removed from the cache.
"""
def __init__(self, items, dict_remove=None):
def remove(k, selfref=ref(self), dict_remove=dict_remove):
self = selfref()
if self is not None and dict_remove is not None:
dict_remove(self)
self._items, self._selectors = unzip(self._try_ref(item, remove)
for item in items)
self._items = tuple(self._items)
self._selectors = tuple(self._selectors)
def __getitem__(self, index):
return self._items[index]
def __len__(self):
return len(self._items)
@staticmethod
def _try_ref(item, callback):
try:
return ref(item, callback), True
except TypeError:
return item, False
@property
def alive(self):
return all(item() is not None
for item in compress(self._items, self._selectors))
def __eq__(self, other):
return self._items == other._items
def __hash__(self):
try:
return self.__hash
except AttributeError:
h = self.__hash = hash(self._items)
return h
class _WeakArgsDict(WeakKeyDictionary, object):
def __delitem__(self, key):
del self.data[_WeakArgs(key)]
def __getitem__(self, key):
return self.data[_WeakArgs(key)]
def __repr__(self):
return '%s(%r)' % (type(self).__name__, self.data)
def __setitem__(self, key, value):
self.data[_WeakArgs(key, self._remove)] = value
def __contains__(self, key):
try:
wr = _WeakArgs(key)
except TypeError:
return False
return wr in self.data
def pop(self, key, *args):
return self.data.pop(_WeakArgs(key), *args)
class _WeakArgsOrderedDict(_WeakArgsDict, object):
def __init__(self):
super(_WeakArgsOrderedDict, self).__init__()
self.data = OrderedDict()
def popitem(self, last=True):
while True:
key, value = self.data.popitem(last)
if key.alive:
return tuple(key), value
def move_to_end(self, key):
"""Move an existing element to the end.
Raises KeyError if the element does not exist.
"""
self[key] = self.pop(key)
def weak_lru_cache(maxsize=100):
"""Weak least-recently-used cache decorator.
If *maxsize* is set to None, the LRU features are disabled and the cache
can grow without bound.
Arguments to the cached function must be hashable. Any that are weak-
referenceable will be stored by weak reference. Once any of the args have
been garbage collected, the entry will be removed from the cache.
View the cache statistics named tuple (hits, misses, maxsize, currsize)
with f.cache_info(). Clear the cache and statistics with f.cache_clear().
See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used
"""
class desc(lazyval):
def __get__(self, instance, owner):
if instance is None:
return self
try:
return self._cache[instance]
except KeyError:
inst = ref(instance)
@_weak_lru_cache(maxsize)
@wraps(self._get)
def wrapper(*args, **kwargs):
return self._get(inst(), *args, **kwargs)
self._cache[instance] = wrapper
return wrapper
@_weak_lru_cache(maxsize)
def __call__(self, *args, **kwargs):
return self._get(*args, **kwargs)
return desc
remember_last = weak_lru_cache(1)