diff --git a/conda/bcolz/meta.yaml b/conda/bcolz/meta.yaml index 7b785622..5e81384d 100644 --- a/conda/bcolz/meta.yaml +++ b/conda/bcolz/meta.yaml @@ -29,9 +29,9 @@ source: requirements: build: - python - - setuptools + - setuptools >18.0 - numpy x.x - - setuptools_scm + - setuptools_scm >1.5.4 - cython ==0.22.1 run: @@ -42,7 +42,7 @@ test: # Python imports imports: - bcolz -# - bcolz.tests + - bcolz.tests # commands: # You can put test commands to be run here. Use this to test that the @@ -52,9 +52,9 @@ test: # You can also put a file called run_test.py in the recipe that will be run # at test time. -# requires: -# - mock -# - unittest2 ; python_version < 2.7 + requires: + - mock + - unittest2 # [py26] # Put any additional test requirements here. For example # - nose diff --git a/etc/requirements.txt b/etc/requirements.txt index a90571c8..7ffa9527 100644 --- a/etc/requirements.txt +++ b/etc/requirements.txt @@ -33,9 +33,6 @@ cyordereddict==0.2.2 # faster array ops. bottleneck==1.0.0 -# lru_cache -functools32==3.2.3.post2;python_version<'3.0' - contextlib2==0.4.0 # networkx requires decorator diff --git a/setup.py b/setup.py index edddf402..c7e5594d 100644 --- a/setup.py +++ b/setup.py @@ -140,8 +140,9 @@ def _filter_requirements(lines_iter, filter_names=None, yield line -# We don't currently have any known upper bounds. -REQ_UPPER_BOUNDS = {} +REQ_UPPER_BOUNDS = { + 'bcolz': '<1' +} def _with_bounds(req): diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 700818d0..bc89e5f9 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1081,10 +1081,10 @@ class TestBeforeTradingStart(TestCase): cls.trading_days[-1], 2, 50 ) - cls.minute_reader = BcolzMinuteBarReader(cls.tempdir.path) - cls.adj_reader = cls.create_adjustments_reader() + minute_reader = BcolzMinuteBarReader(cls.tempdir.path) + adj_reader = cls.create_adjustments_reader() - cls.daily_path = cls.tempdir.getpath("testdaily.bcolz") + daily_path = cls.tempdir.getpath("testdaily.bcolz") dfs = { 1: create_daily_df_for_asset(cls.env, cls.trading_days[0], cls.trading_days[-1]), @@ -1094,7 +1094,7 @@ class TestBeforeTradingStart(TestCase): cls.trading_days[-1]) } daily_writer = DailyBarWriterFromDataFrames(dfs) - daily_writer.write(cls.daily_path, cls.trading_days, dfs) + daily_writer.write(daily_path, cls.trading_days, dfs) cls.sim_params = SimulationParameters( period_start=cls.trading_days[1], @@ -1105,9 +1105,9 @@ class TestBeforeTradingStart(TestCase): cls.data_portal = DataPortal( env=cls.env, - equity_daily_reader=BcolzDailyBarReader(cls.daily_path), - equity_minute_reader=cls.minute_reader, - adjustment_reader=cls.adj_reader + equity_daily_reader=BcolzDailyBarReader(daily_path), + equity_minute_reader=minute_reader, + adjustment_reader=adj_reader ) @classmethod @@ -1131,15 +1131,15 @@ class TestBeforeTradingStart(TestCase): # Mergers and Dividends are not tested, but we need to have these # anyway mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid']) - mergers.effective_date = mergers.effective_date.astype(int) - mergers.ratio = mergers.ratio.astype(float) - mergers.sid = mergers.sid.astype(int) + mergers.effective_date = mergers.effective_date.astype(np.int64) + mergers.ratio = mergers.ratio.astype(np.float64) + mergers.sid = mergers.sid.astype(np.int64) dividends = pd.DataFrame({}, columns=['ex_date', 'record_date', 'declared_date', 'pay_date', 'amount', 'sid']) - dividends.amount = dividends.amount.astype(float) - dividends.sid = dividends.sid.astype(int) + dividends.amount = dividends.amount.astype(np.float64) + dividends.sid = dividends.sid.astype(np.int64) adj_writer.write(splits, mergers, dividends) @@ -1147,6 +1147,8 @@ class TestBeforeTradingStart(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal + del cls.env cls.tempdir.cleanup() def test_data_in_bts_minute(self): @@ -1478,6 +1480,7 @@ class TestAlgoScript(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal del cls.env cls.tempdir.cleanup() teardown_logger(cls) @@ -1867,6 +1870,7 @@ class TestGetDatetime(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal del cls.env teardown_logger(cls) cls.tempdir.cleanup() @@ -1946,6 +1950,7 @@ class TestTradingControls(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal del cls.env cls.tempdir.cleanup() @@ -2362,6 +2367,7 @@ class TestAccountControls(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal del cls.env cls.tempdir.cleanup() @@ -2529,6 +2535,7 @@ class TestFutureFlip(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal cls.tempdir.cleanup() @skip @@ -2648,6 +2655,7 @@ class TestOrderCancelation(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal cls.tempdir.cleanup() @classmethod diff --git a/tests/test_api_shim.py b/tests/test_api_shim.py index d33c3e23..b04fb09d 100644 --- a/tests/test_api_shim.py +++ b/tests/test_api_shim.py @@ -164,6 +164,11 @@ class TestAPIShim(TestCase): env=cls.env ) + @classmethod + def tearDownClass(cls): + del cls.adj_reader + cls.tempdir.cleanup() + @classmethod def build_daily_data(cls): path = cls.tempdir.getpath("testdaily.bcolz") @@ -203,24 +208,20 @@ class TestAPIShim(TestCase): # Mergers and Dividends are not tested, but we need to have these # anyway mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid']) - mergers.effective_date = mergers.effective_date.astype(int) - mergers.ratio = mergers.ratio.astype(float) - mergers.sid = mergers.sid.astype(int) + mergers.effective_date = mergers.effective_date.astype(np.int64) + mergers.ratio = mergers.ratio.astype(np.float64) + mergers.sid = mergers.sid.astype(np.int64) dividends = pd.DataFrame({}, columns=['ex_date', 'record_date', 'declared_date', 'pay_date', 'amount', 'sid']) - dividends.amount = dividends.amount.astype(float) - dividends.sid = dividends.sid.astype(int) + dividends.amount = dividends.amount.astype(np.float64) + dividends.sid = dividends.sid.astype(np.int64) adj_writer.write(splits, mergers, dividends) return SQLiteAdjustmentReader(path) - @classmethod - def tearDownClass(cls): - cls.tempdir.cleanup() - def setUp(self): self.data_portal = DataPortal( self.env, @@ -229,6 +230,9 @@ class TestAPIShim(TestCase): adjustment_reader=self.adj_reader ) + def tearDown(self): + del self.data_portal + @classmethod def create_algo(cls, code, filename=None, sim_params=None): if sim_params is None: diff --git a/tests/test_bar_data.py b/tests/test_bar_data.py index bec4d34b..08ae9fdb 100644 --- a/tests/test_bar_data.py +++ b/tests/test_bar_data.py @@ -131,6 +131,8 @@ class TestMinuteBarData(TestBarDataBase): @classmethod def tearDownClass(cls): + del cls.data_portal + del cls.adjustments_reader cls.tempdir.cleanup() @classmethod @@ -564,6 +566,8 @@ class TestDailyBarData(TestBarDataBase): @classmethod def tearDownClass(cls): + del cls.data_portal + del cls.adjustments_reader cls.tempdir.cleanup() @classmethod diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index a0f6757b..2b2960f4 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -100,6 +100,7 @@ class TestBenchmark(TestCase): @classmethod def tearDownClass(cls): + del cls.data_portal del cls.env cls.tempdir.cleanup() diff --git a/tests/test_finance.py b/tests/test_finance.py index 142272a7..5ba73a96 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -286,7 +286,7 @@ class FinanceTestCase(TestCase): else: alternator = 1 - tracker = PerformanceTracker(sim_params, self.env, data_portal) + tracker = PerformanceTracker(sim_params, self.env) # replicate what tradesim does by going through every minute or day # of the simulation and processing open orders each time diff --git a/tests/test_history.py b/tests/test_history.py index b205e34a..bb59ee03 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -108,11 +108,18 @@ class HistoryTestCaseBase(TestCase): cls.create_data() + @classmethod + def tearDownClass(cls): + del cls.adj_reader + cls.tempdir.cleanup() + def setUp(self): self.create_data_portal() - @classmethod - def create_data_portal(cls): + def tearDown(self): + del self.data_portal + + def create_data_portal(self): raise NotImplementedError() @classmethod @@ -169,10 +176,6 @@ class HistoryTestCaseBase(TestCase): } }) - @classmethod - def tearDownClass(cls): - cls.tempdir.cleanup() - @classmethod def create_adjustments_reader(cls): path = cls.tempdir.getpath("test_adjustments.db") @@ -470,12 +473,11 @@ MINUTE_FIELD_INFO = { class MinuteEquityHistoryTestCase(HistoryTestCaseBase): - @classmethod - def create_data_portal(cls): - cls.data_portal = DataPortal( - cls.env, - equity_minute_reader=BcolzMinuteBarReader(cls.tempdir.path), - adjustment_reader=cls.adj_reader + def create_data_portal(self): + self.data_portal = DataPortal( + self.env, + equity_minute_reader=BcolzMinuteBarReader(self.tempdir.path), + adjustment_reader=self.adj_reader ) @classmethod @@ -1016,15 +1018,14 @@ class MinuteEquityHistoryTestCase(HistoryTestCaseBase): class DailyEquityHistoryTestCase(HistoryTestCaseBase): - @classmethod - def create_data_portal(cls): - daily_path = cls.tempdir.getpath("testdaily.bcolz") + def create_data_portal(self): + daily_path = self.tempdir.getpath("testdaily.bcolz") - cls.data_portal = DataPortal( - cls.env, + self.data_portal = DataPortal( + self.env, equity_daily_reader=BcolzDailyBarReader(daily_path), - equity_minute_reader=BcolzMinuteBarReader(cls.tempdir.path), - adjustment_reader=cls.adj_reader + equity_minute_reader=BcolzMinuteBarReader(self.tempdir.path), + adjustment_reader=self.adj_reader ) @classmethod diff --git a/tests/test_memoize.py b/tests/test_memoize.py index 75996e05..c1621249 100644 --- a/tests/test_memoize.py +++ b/tests/test_memoize.py @@ -1,6 +1,8 @@ """ Tests for zipline.utils.memoize. """ +from collections import defaultdict +import gc from unittest import TestCase from zipline.utils.memoize import remember_last @@ -32,3 +34,63 @@ class TestRememberLast(TestCase): # Calling the old value should still increment the counter. self.assertEqual((func(1), call_count[0]), (1, 3)) self.assertEqual((func(1), call_count[0]), (1, 3)) + + def test_remember_last_method(self): + call_count = defaultdict(int) + + class clz(object): + @remember_last + def func(self, x): + call_count[(self, x)] += 1 + return x + + inst1 = clz() + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1})) + + # Calling again with the same argument should just re-use the old + # value, which means func shouldn't get called again. + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1})) + + # Calling with a new value should increment the counter. + self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1, + (inst1, 2): 1})) + self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1, + (inst1, 2): 1})) + + # Calling the old value should still increment the counter. + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2, + (inst1, 2): 1})) + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2, + (inst1, 2): 1})) + + inst2 = clz() + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1})) + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1})) + + self.assertEqual((inst2.func(2), call_count), + (2, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1, (inst2, 2): 1})) + self.assertEqual((inst2.func(2), call_count), + (2, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1, (inst2, 2): 1})) + + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 2, (inst2, 2): 1})) + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 2, (inst2, 2): 1})) + + # Remove the above references to the instances and ensure that + # remember_last has not made its own. + del inst1, inst2 + call_count.clear() + while gc.collect(): + pass + + self.assertFalse([inst for inst in gc.get_objects() + if type(inst) == clz]) diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 11b3886e..4174ea52 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -50,7 +50,7 @@ from zipline.utils.serialization_utils import ( loads_with_persistent_ids, dumps_with_persistent_ids ) from zipline.testing.core import create_data_portal_from_trade_history, \ - create_empty_splits_mergers_frame, FakeDataPortal + create_empty_splits_mergers_frame logger = logging.getLogger('Test Perf Tracking') @@ -167,7 +167,7 @@ def calculate_results(sim_params, splits = splits or {} commissions = commissions or {} - perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal) + perf_tracker = perf.PerformanceTracker(sim_params, env) results = [] @@ -189,8 +189,10 @@ def calculate_results(sim_params, except KeyError: pass - msg = perf_tracker.handle_market_close_daily(date) - perf_tracker.position_tracker.sync_last_sale_prices(date, False) + msg = perf_tracker.handle_market_close_daily(date, data_portal) + perf_tracker.position_tracker.sync_last_sale_prices( + date, False, data_portal, + ) msg['account'] = perf_tracker.get_account(True) results.append(copy.deepcopy(msg)) return results @@ -265,9 +267,7 @@ class TestSplitPerformance(unittest.TestCase): def test_multiple_splits(self): # if multiple positions all have splits at the same time, verify that # the total leftover cash is correct - perf_tracker = perf.PerformanceTracker( - self.sim_params, self.env, FakeDataPortal() - ) + perf_tracker = perf.PerformanceTracker(self.sim_params, self.env) asset1 = self.env.asset_finder.retrieve_asset(1) asset2 = self.env.asset_finder.retrieve_asset(2) @@ -1240,11 +1240,10 @@ class TestPositionPerformance(unittest.TestCase): txn1 = create_txn(self.asset1, trades_1[0].dt, 10.0, 100) txn2 = create_txn(self.asset2, trades_1[0].dt, 10.0, -100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn1) pp.handle_execution(txn1) @@ -1252,7 +1251,7 @@ class TestPositionPerformance(unittest.TestCase): pp.handle_execution(txn2) dt = trades_1[-2].dt - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() @@ -1280,7 +1279,7 @@ class TestPositionPerformance(unittest.TestCase): net_liquidation=1000.0) dt = trades_1[-1].dt - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() @@ -1333,11 +1332,10 @@ class TestPositionPerformance(unittest.TestCase): self.sim_params, {1: trades}) txn = create_txn(self.asset1, trades[1].dt, 10.0, 1000) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1355,7 +1353,7 @@ class TestPositionPerformance(unittest.TestCase): shorts_count=0) # Validate that the account attributes were updated. - pt.sync_last_sale_prices(trades[-2].dt, False) + pt.sync_last_sale_prices(trades[-2].dt, False, data_portal) # Validate that the account attributes were updated. account = pp.as_account() @@ -1373,7 +1371,7 @@ class TestPositionPerformance(unittest.TestCase): net_liquidation=1000.0) # now simulate a price jump to $11 - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -1425,11 +1423,10 @@ class TestPositionPerformance(unittest.TestCase): self.sim_params, {1: trades}) txn = create_txn(self.asset1, trades[1].dt, 10.0, 100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, self.sim_params.data_frequency, - data_portal, period_open=self.sim_params.period_start, period_close=self.sim_params.period_end) pp.position_tracker = pt @@ -1444,7 +1441,7 @@ class TestPositionPerformance(unittest.TestCase): # stocks with a last sale price of 0. self.assertEqual(pp.positions[1].last_sale_price, 10.0) - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -1544,18 +1541,17 @@ single short-sale transaction""" {1: trades}) txn = create_txn(self.asset1, trades[1].dt, 10.0, -100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod( 1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) pp.handle_execution(txn) - pt.sync_last_sale_prices(trades_1[-1].dt, False) + pt.sync_last_sale_prices(trades_1[-1].dt, False, data_portal) pp.calculate_performance() @@ -1611,7 +1607,7 @@ single short-sale transaction""" # simulate a rollover to a new period pp.rollover() - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -1665,17 +1661,16 @@ single short-sale transaction""" ) # now run a performance period encompassing the entire trade sample. - ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal, + ptTotal = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) ppTotal.position_tracker = pt ptTotal.execute_transaction(txn) ppTotal.handle_execution(txn) - ptTotal.sync_last_sale_prices(trades[-1].dt, False) + ptTotal.sync_last_sale_prices(trades[-1].dt, False, data_portal) ppTotal.calculate_performance() @@ -1778,11 +1773,10 @@ cost of sole txn in test" ) txn = create_txn(self.asset3, trades[1].dt, 10.0, 1) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1795,7 +1789,7 @@ cost of sole txn in test" # stocks with a last sale price of 0. self.assertEqual(pp.positions[3].last_sale_price, 10.0) - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -1899,17 +1893,16 @@ single short-sale transaction""" trades_1 = trades[:-2] txn = create_txn(self.asset3, trades[0].dt, 10.0, -1) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) pp.handle_execution(txn) - pt.sync_last_sale_prices(trades[-3].dt, False) + pt.sync_last_sale_prices(trades[-3].dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -1969,7 +1962,7 @@ single short-sale transaction""" # simulate a rollover to a new period pp.rollover() - pt.sync_last_sale_prices(trades_2[-1].dt, False) + pt.sync_last_sale_prices(trades_2[-1].dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -2027,21 +2020,20 @@ single short-sale transaction""" ) # now run a performance period encompassing the entire trade sample. - ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal, + ptTotal = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) ppTotal.position_tracker = ptTotal for trade in trades_1: - ptTotal.sync_last_sale_prices(trade.dt, False) + ptTotal.sync_last_sale_prices(trade.dt, False, data_portal) ptTotal.execute_transaction(txn) ppTotal.handle_execution(txn) for trade in trades_2: - ptTotal.sync_last_sale_prices(trade.dt, False) + ptTotal.sync_last_sale_prices(trade.dt, False, data_portal) ppTotal.calculate_performance() @@ -2144,11 +2136,10 @@ trade after cover""" short_txn = create_txn(self.asset1, trades[1].dt, 10.0, -100) cover_txn = create_txn(self.asset1, trades[6].dt, 7.0, 100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(short_txn) @@ -2156,7 +2147,7 @@ trade after cover""" pt.execute_transaction(cover_txn) pp.handle_execution(cover_txn) - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -2231,13 +2222,12 @@ shares in position" self.sim_params, {1: trades}) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod( 1000.0, self.env.asset_finder, self.sim_params.data_frequency, - data_portal, period_open=self.sim_params.period_start, period_close=self.sim_params.trading_days[-1] ) @@ -2264,7 +2254,7 @@ shares in position" "should have a cost basis of 11" ) - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() @@ -2281,7 +2271,7 @@ shares in position" pp.handle_execution(sale_txn) dt = down_tick.dt - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -2299,11 +2289,10 @@ shares in position" self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400") - pt3 = perf.PositionTracker(self.env.asset_finder, data_portal, + pt3 = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp3.position_tracker = pt3 average_cost = 0 @@ -2317,7 +2306,7 @@ shares in position" pp3.handle_execution(sale_txn) trades.append(down_tick) - pt3.sync_last_sale_prices(trades[-1].dt, False) + pt3.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp3.calculate_performance() self.assertEqual( @@ -2351,18 +2340,11 @@ shares in position" ) cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5] - trades = factory.create_trade_history(*history_args) transactions = factory.create_txn_history(*history_args) - data_portal = create_data_portal_from_trade_history( - self.env, - self.tempdir, - self.sim_params, - {1: trades}) - - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) - pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, data_portal, + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, self.sim_params.data_frequency) pp.position_tracker = pt @@ -2413,22 +2395,8 @@ class TestPositionTracker(unittest.TestCase): sim_params = factory.create_simulation_parameters( num_days=4, env=self.env ) - trades = factory.create_trade_history( - 1, - [10, 10, 10, 11], - [100, 100, 100, 100], - oneday, - sim_params, - env=self.env - ) - data_portal = create_data_portal_from_trade_history( - self.env, - self.tempdir, - sim_params, - {1: trades}) - - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, sim_params.data_frequency) pos_stats = pt.stats() @@ -2450,7 +2418,7 @@ class TestPositionTracker(unittest.TestCase): self.assertNotIsInstance(val, (bool, np.bool_)) def test_position_values_and_exposures(self): - pt = perf.PositionTracker(self.env.asset_finder, None, None) + pt = perf.PositionTracker(self.env.asset_finder, None) dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(10.0), last_sale_date=dt, last_sale_price=10) @@ -2482,7 +2450,7 @@ class TestPositionTracker(unittest.TestCase): self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure) def test_update_positions(self): - pt = perf.PositionTracker(self.env.asset_finder, None, None) + pt = perf.PositionTracker(self.env.asset_finder, None) dt = pd.Timestamp("2014/01/01 3:00PM") pos1 = perf.Position(1, amount=np.float64(10.0), last_sale_date=dt, last_sale_price=10) diff --git a/tests/utils/daily_bar_writer.py b/tests/utils/daily_bar_writer.py index 344e664f..df69c55c 100644 --- a/tests/utils/daily_bar_writer.py +++ b/tests/utils/daily_bar_writer.py @@ -1,6 +1,7 @@ from numpy import ( float64, - uint32 + uint32, + int64, ) from bcolz import ctable @@ -37,8 +38,9 @@ class DailyBarWriterFromDataFrames(BcolzDailyBarWriter): return array.astype(uint32) elif colname == 'day': nanos_per_second = (1000 * 1000 * 1000) - self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname) - return (array.view(int) / nanos_per_second).astype(uint32) + self.check_uint_safe(arrmax.view(int64) / nanos_per_second, + colname) + return (array.view(int64) / nanos_per_second).astype(uint32) @staticmethod def check_uint_safe(value, colname): diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 8e204710..98f75b50 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -492,7 +492,6 @@ class TradingAlgorithm(object): self.perf_tracker = PerformanceTracker( sim_params=self.sim_params, env=self.trading_environment, - data_portal=self.data_portal ) # Set the dt initially to the period start by forcing it to change. @@ -603,14 +602,17 @@ class TradingAlgorithm(object): # Create zipline and loop through simulated_trading. # Each iteration returns a perf dictionary - perfs = [] - for perf in self.get_generator(): - perfs.append(perf) + try: + perfs = [] + for perf in self.get_generator(): + perfs.append(perf) - # convert perf dict to pandas dataframe - daily_stats = self._create_daily_stats(perfs) + # convert perf dict to pandas dataframe + daily_stats = self._create_daily_stats(perfs) - self.analyze(daily_stats) + self.analyze(daily_stats) + finally: + self.data_portal = None return daily_stats @@ -1057,7 +1059,7 @@ class TradingAlgorithm(object): def updated_portfolio(self): if self.portfolio_needs_update: self.perf_tracker.position_tracker.sync_last_sale_prices( - self.datetime, self._in_before_trading_start) + self.datetime, self._in_before_trading_start, self.data_portal) self._portfolio = \ self.perf_tracker.get_portfolio(self.performance_needs_update) self.portfolio_needs_update = False @@ -1071,7 +1073,7 @@ class TradingAlgorithm(object): def updated_account(self): if self.account_needs_update: self.perf_tracker.position_tracker.sync_last_sale_prices( - self.datetime, self._in_before_trading_start) + self.datetime, self._in_before_trading_start, self.data_portal) self._account = \ self.perf_tracker.get_account(self.performance_needs_update) self.account_needs_update = False diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index 9db60e5c..f8598c38 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -31,13 +31,12 @@ from zipline.data.us_equity_loader import ( ) from zipline.utils import tradingcalendar -from zipline.utils.compat import lru_cache from zipline.utils.math_utils import ( nansum, nanmean, nanstd ) -from zipline.utils.memoize import remember_last +from zipline.utils.memoize import remember_last, weak_lru_cache from zipline.errors import ( NoTradeDataAvailableTooEarly, NoTradeDataAvailableTooLate, @@ -1091,10 +1090,6 @@ class DataPortal(object): return data - @remember_last - def _get_market_minutes_for_day(self, end_date): - return self.env.market_minutes_for_day(pd.Timestamp(end_date)) - def _get_history_daily_window_equities( self, assets, days_for_window, end_dt, field_to_use): ends_at_midnight = end_dt.hour == 0 and end_dt.minute == 0 @@ -1686,7 +1681,7 @@ class DataPortal(object): else: return [assets] if isinstance(assets, Asset) else [] - @lru_cache(20) + @weak_lru_cache(20) def _get_minute_count_for_transform(self, ending_minute, days_count): # cache size picked somewhat loosely. this code exists purely to # handle deprecated API. diff --git a/zipline/data/minute_bars.py b/zipline/data/minute_bars.py index fbee5f4c..09c9d994 100644 --- a/zipline/data/minute_bars.py +++ b/zipline/data/minute_bars.py @@ -163,10 +163,10 @@ class BcolzMinuteBarMetadata(object): 'first_trading_day': str(self.first_trading_day.date()), 'market_opens': self.market_opens.values. astype('datetime64[m]'). - astype(int).tolist(), + astype(np.int64).tolist(), 'market_closes': self.market_closes.values. astype('datetime64[m]'). - astype(int).tolist(), + astype(np.int64).tolist(), 'ohlc_ratio': self.ohlc_ratio, } with open(self.metadata_path(rootdir), 'w+') as fp: @@ -603,10 +603,10 @@ class BcolzMinuteBarReader(object): self._market_opens = metadata.market_opens self._market_open_values = metadata.market_opens.values.\ - astype('datetime64[m]').astype(int) + astype('datetime64[m]').astype(np.int64) self._market_closes = metadata.market_closes self._market_close_values = metadata.market_closes.values.\ - astype('datetime64[m]').astype(int) + astype('datetime64[m]').astype(np.int64) self._ohlc_inverse = 1.0 / metadata.ohlc_ratio @@ -643,7 +643,7 @@ class BcolzMinuteBarReader(object): """ market_opens = self._market_opens.values.astype('datetime64[m]') market_closes = self._market_closes.values.astype('datetime64[m]') - minutes_per_day = (market_closes - market_opens).astype(int) + minutes_per_day = (market_closes - market_opens).astype(np.int64) early_indices = np.where( minutes_per_day != US_EQUITIES_MINUTES_PER_DAY - 1)[0] regular_closes = market_opens[early_indices] + timedelta64( diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 1488d666..23a1864c 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -133,7 +133,6 @@ class PerformancePeriod(object): starting_cash, asset_finder, data_frequency, - data_portal, period_open=None, period_close=None, keep_transactions=True, @@ -144,8 +143,6 @@ class PerformancePeriod(object): self.asset_finder = asset_finder self.data_frequency = data_frequency - self._data_portal = data_portal - self.period_open = period_open self.period_close = period_close diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index 88d68dc5..ea1600a1 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -119,14 +119,9 @@ def calc_gross_value(long_value, short_value): class PositionTracker(object): - def __init__(self, asset_finder, data_portal, data_frequency): + def __init__(self, asset_finder, data_frequency): self.asset_finder = asset_finder - # FIXME really want to avoid storing a data portal here, - # but the path to get to maybe_create_close_position_transaction - # is long and tortuous - self._data_portal = data_portal - # sid => position object self.positions = positiondict() # Arrays for quick calculations of positions value @@ -316,12 +311,12 @@ class PositionTracker(object): return net_cash_payment - def maybe_create_close_position_transaction(self, asset, dt): + def maybe_create_close_position_transaction(self, asset, dt, data_portal): if not self.positions.get(asset): return None amount = self.positions.get(asset).amount - price = self._data_portal.get_spot_value( + price = data_portal.get_spot_value( asset, 'price', dt, self.data_frequency) # Get the last traded price if price is no longer available @@ -372,8 +367,8 @@ class PositionTracker(object): positions.append(pos.to_dict()) return positions - def sync_last_sale_prices(self, dt, handle_non_market_minutes): - data_portal = self._data_portal + def sync_last_sale_prices(self, dt, handle_non_market_minutes, + data_portal): if not handle_non_market_minutes: for asset, position in iteritems(self.positions): last_sale_price = data_portal.get_spot_value( diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index 6307416b..ff98ebb1 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -78,7 +78,7 @@ class PerformanceTracker(object): """ Tracks the performance of the algorithm. """ - def __init__(self, sim_params, env, data_portal): + def __init__(self, sim_params, env): self.sim_params = sim_params self.env = env @@ -101,15 +101,8 @@ class PerformanceTracker(object): self.trading_days = all_trading_days[mask] - self._data_portal = data_portal - if data_portal is not None: - self._adjustment_reader = data_portal._adjustment_reader - else: - self._adjustment_reader = None - self.position_tracker = PositionTracker( asset_finder=env.asset_finder, - data_portal=data_portal, data_frequency=self.sim_params.data_frequency) if self.emission_rate == 'daily': @@ -132,7 +125,6 @@ class PerformanceTracker(object): # initial cash is your capital base. starting_cash=self.capital_base, data_frequency=self.sim_params.data_frequency, - data_portal=data_portal, # the cumulative period will be calculated over the entire test. period_open=self.period_start, period_close=self.period_end, @@ -152,7 +144,6 @@ class PerformanceTracker(object): # initial cash is your capital base. starting_cash=self.capital_base, data_frequency=self.sim_params.data_frequency, - data_portal=data_portal, # the daily period will be calculated for the market day period_open=self.market_open, period_close=self.market_close, @@ -264,13 +255,13 @@ class PerformanceTracker(object): self.cumulative_performance.handle_commission(cost) self.todays_performance.handle_commission(cost) - def process_close_position(self, asset, dt): + def process_close_position(self, asset, dt, data_portal): txn = self.position_tracker.\ - maybe_create_close_position_transaction(asset, dt) + maybe_create_close_position_transaction(asset, dt, data_portal) if txn: self.process_transaction(txn) - def check_upcoming_dividends(self, next_trading_day): + def check_upcoming_dividends(self, next_trading_day, adjustment_reader): """ Check if we currently own any stocks with dividends whose ex_date is the next trading day. Track how much we should be payed on those @@ -280,7 +271,7 @@ class PerformanceTracker(object): is the next trading day. Apply all such benefits, then recalculate performance. """ - if self._adjustment_reader is None: + if adjustment_reader is None: return position_tracker = self.position_tracker held_sids = set(position_tracker.positions) @@ -291,10 +282,10 @@ class PerformanceTracker(object): if held_sids: asset_finder = self.env.asset_finder - cash_dividends = self._adjustment_reader.\ + cash_dividends = adjustment_reader.\ get_dividends_with_ex_date(held_sids, next_trading_day, asset_finder) - stock_dividends = self._adjustment_reader.\ + stock_dividends = adjustment_reader.\ get_stock_dividends_with_ex_date(held_sids, next_trading_day, asset_finder) @@ -310,7 +301,7 @@ class PerformanceTracker(object): self.cumulative_performance.handle_dividends_paid(net_cash_payment) self.todays_performance.handle_dividends_paid(net_cash_payment) - def handle_minute_close(self, dt): + def handle_minute_close(self, dt, data_portal): """ Handles the close of the given minute. This includes handling market-close functions if the given minute is the end of the market @@ -327,7 +318,7 @@ class PerformanceTracker(object): A tuple of the minute perf packet and daily perf packet. If the market day has not ended, the daily perf packet is None. """ - self.position_tracker.sync_last_sale_prices(dt, False) + self.position_tracker.sync_last_sale_prices(dt, False, data_portal) self.update_performance() todays_date = normalize_date(dt) account = self.get_account(False) @@ -346,16 +337,18 @@ class PerformanceTracker(object): # if this is the close, update dividends for the next day. # Return the performance tuple if dt == self.market_close: - return minute_packet, self._handle_market_close(todays_date) + return minute_packet, self._handle_market_close( + todays_date, data_portal._adjustment_reader, + ) else: return minute_packet, None - def handle_market_close_daily(self, dt): + def handle_market_close_daily(self, dt, data_portal): """ Function called after handle_data when running with daily emission rate. """ - self.position_tracker.sync_last_sale_prices(dt, False) + self.position_tracker.sync_last_sale_prices(dt, False, data_portal) self.update_performance() completed_date = self.day account = self.get_account(False) @@ -368,11 +361,12 @@ class PerformanceTracker(object): benchmark_value, account.leverage) - daily_packet = self._handle_market_close(completed_date) - + daily_packet = self._handle_market_close( + completed_date, data_portal._adjustment_reader, + ) return daily_packet - def _handle_market_close(self, completed_date): + def _handle_market_close(self, completed_date, adjustment_reader): # increment the day counter before we move markers forward. self.day_count += 1.0 @@ -406,7 +400,8 @@ class PerformanceTracker(object): return daily_update # Check for any dividends, then return the daily perf packet - self.check_upcoming_dividends(next_trading_day=next_trading_day) + self.check_upcoming_dividends(next_trading_day=next_trading_day, + adjustment_reader=adjustment_reader) return daily_update def handle_simulation_end(self): diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index b9200a5b..33a9c432 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -97,20 +97,9 @@ class AlgorithmSimulator(object): Main generator work loop. """ algo = self.algo - algo.data_portal = self.data_portal - handle_data = algo.event_manager.handle_data - current_data = self.current_data - data_portal = self.data_portal - - # can't cache a pointer to algo.perf_tracker because we're not - # guaranteed that the algo doesn't swap out perf trackers during - # its lifetime. - # likewise, we can't cache a pointer to the blotter. - - algo.perf_tracker.position_tracker.data_portal = data_portal - - def every_bar(dt_to_use): + def every_bar(dt_to_use, current_data=self.current_data, + handle_data=algo.event_manager.handle_data): # called every tick (minute or day). self.simulation_dt = dt_to_use @@ -152,7 +141,8 @@ class AlgorithmSimulator(object): self.algo.account_needs_update = True self.algo.performance_needs_update = True - def once_a_day(midnight_dt): + def once_a_day(midnight_dt, current_data=self.current_data, + data_portal=self.data_portal): # Get the positions before updating the date so that prices are # fetched for trading close instead of midnight positions = algo.perf_tracker.position_tracker.positions @@ -183,11 +173,15 @@ class AlgorithmSimulator(object): # call before trading start algo.before_trading_start(current_data) - def handle_benchmark(date): + def handle_benchmark(date, benchmark_source=self.benchmark_source): algo.perf_tracker.all_benchmark_returns[date] = \ - self.benchmark_source.get_value(date) + benchmark_source.get_value(date) + + def on_exit(): + self.benchmark_source = self.current_data = self.data_portal = None with ExitStack() as stack: + stack.callback(on_exit) stack.enter_context(self.processor) stack.enter_context(ZiplineAPI(self.algo)) @@ -245,8 +239,9 @@ class AlgorithmSimulator(object): assets_to_clear = \ [asset for asset in position_assets if past_auto_close_date(asset)] perf_tracker = algo.perf_tracker + data_portal = self.data_portal for asset in assets_to_clear: - perf_tracker.process_close_position(asset, dt) + perf_tracker.process_close_position(asset, dt, data_portal) # Remove open orders for any sids that have reached their # auto_close_date. @@ -257,23 +252,25 @@ class AlgorithmSimulator(object): for asset in assets_to_cancel: blotter.cancel_all_orders_for_asset(asset) - @staticmethod - def _get_daily_message(dt, algo, perf_tracker): + def _get_daily_message(self, dt, algo, perf_tracker): """ Get a perf message for the given datetime. """ - perf_message = perf_tracker.handle_market_close_daily(dt) + perf_message = perf_tracker.handle_market_close_daily( + dt, self.data_portal, + ) perf_message['daily_perf']['recorded_vars'] = algo.recorded_vars return perf_message - @staticmethod - def _get_minute_message(dt, algo, perf_tracker): + def _get_minute_message(self, dt, algo, perf_tracker): """ Get a perf message for the given datetime. """ rvars = algo.recorded_vars - minute_message, daily_message = perf_tracker.handle_minute_close(dt) + minute_message, daily_message = perf_tracker.handle_minute_close( + dt, self.data_portal, + ) minute_message['minute_perf']['recorded_vars'] = rvars if daily_message: diff --git a/zipline/protocol.py b/zipline/protocol.py index da0d2397..9db94914 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -16,9 +16,8 @@ import pandas as pd from .utils.enum import enum -from zipline._protocol import BarData as _BarData +from zipline._protocol import BarData # noqa -BarData = _BarData # Datasource type should completely determine the other fields of a # message with its type. diff --git a/zipline/testing/core.py b/zipline/testing/core.py index 4a236a3f..8a99b8d2 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -45,7 +45,8 @@ from zipline.utils.tradingcalendar import trading_days import numpy as np from numpy import ( float64, - uint32 + uint32, + int64, ) @@ -456,8 +457,9 @@ def make_trade_data_for_asset_info(dates, sids = asset_info.keys() date_field = 'day' if frequency == 'daily' else 'dt' - price_sid_deltas = np.arange(len(sids), dtype=float) * price_step_by_sid - price_date_deltas = np.arange(len(dates), dtype=float) * price_step_by_date + price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid + price_date_deltas = (np.arange(len(dates), dtype=float64) * + price_step_by_date) prices = (price_sid_deltas + price_date_deltas[:, None]) + price_start volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid @@ -723,8 +725,9 @@ class DailyBarWriterFromDataFrames(BcolzDailyBarWriter): return array.astype(uint32) elif colname == 'day': nanos_per_second = (1000 * 1000 * 1000) - self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname) - return (array.view(int) / nanos_per_second).astype(uint32) + self.check_uint_safe(arrmax.view(int64) / nanos_per_second, + colname) + return (array.view(int64) / nanos_per_second).astype(uint32) @staticmethod def check_uint_safe(value, colname): @@ -1198,8 +1201,8 @@ def create_mock_adjustments(tempdir, days, splits=None, dividends=None, 'pay_date': np.array([], dtype='datetime64[ns]'), 'record_date': np.array([], dtype='datetime64[ns]'), 'declared_date': np.array([], dtype='datetime64[ns]'), - 'amount': np.array([], dtype=float), - 'sid': np.array([], dtype=int), + 'amount': np.array([], dtype=float64), + 'sid': np.array([], dtype=int64), } dividends = pd.DataFrame( data, @@ -1360,9 +1363,9 @@ def create_empty_splits_mergers_frame(): return pd.DataFrame( { # Hackery to make the dtypes correct on an empty frame. - 'effective_date': np.array([], dtype=int), - 'ratio': np.array([], dtype=float), - 'sid': np.array([], dtype=int), + 'effective_date': np.array([], dtype=int64), + 'ratio': np.array([], dtype=float64), + 'sid': np.array([], dtype=int64), }, index=pd.DatetimeIndex([]), columns=['effective_date', 'ratio', 'sid'], diff --git a/zipline/utils/compat.py b/zipline/utils/compat.py deleted file mode 100644 index d0a037c9..00000000 --- a/zipline/utils/compat.py +++ /dev/null @@ -1,7 +0,0 @@ -from six import PY2 - - -if PY2: - from functools32 import lru_cache # noqa -else: - from functools import lru_cache # noqa diff --git a/zipline/utils/memoize.py b/zipline/utils/memoize.py index 1b6a4bf1..f0c77f3d 100644 --- a/zipline/utils/memoize.py +++ b/zipline/utils/memoize.py @@ -1,8 +1,13 @@ """ Tools for memoization of function results. """ -from zipline.utils.compat import lru_cache -from weakref import WeakKeyDictionary +from collections import OrderedDict, Sequence +from functools import wraps +from itertools import compress +from weakref import WeakKeyDictionary, ref + +from six.moves._thread import allocate_lock as Lock +from toolz.sandbox import unzip class lazyval(object): @@ -84,4 +89,211 @@ class classlazyval(lazyval): return super(classlazyval, self).__get__(owner, owner) -remember_last = lru_cache(1) +def _weak_lru_cache(maxsize=100): + """ + Users should only access the lru_cache through its public API: + cache_info, cache_clear + The internals of the lru_cache are encapsulated for thread safety and + to allow the implementation to change. + """ + def decorating_function( + user_function, tuple=tuple, sorted=sorted, len=len, + KeyError=KeyError): + + hits, misses = [0], [0] + kwd_mark = (object(),) # separates positional and keyword args + lock = Lock() # needed because OrderedDict isn't threadsafe + + if maxsize is None: + cache = _WeakArgsDict() # cache without ordering or size limit + + @wraps(user_function) + def wrapper(*args, **kwds): + key = args + if kwds: + key += kwd_mark + tuple(sorted(kwds.items())) + try: + result = cache[key] + hits[0] += 1 + return result + except KeyError: + pass + result = user_function(*args, **kwds) + cache[key] = result + misses[0] += 1 + return result + else: + # ordered least recent to most recent + cache = _WeakArgsOrderedDict() + cache_popitem = cache.popitem + cache_renew = cache.move_to_end + + @wraps(user_function) + def wrapper(*args, **kwds): + key = args + if kwds: + key += kwd_mark + tuple(sorted(kwds.items())) + with lock: + try: + result = cache[key] + cache_renew(key) # record recent use of this key + hits[0] += 1 + return result + except KeyError: + pass + result = user_function(*args, **kwds) + with lock: + cache[key] = result # record recent use of this key + misses[0] += 1 + if len(cache) > maxsize: + # purge least recently used cache entry + cache_popitem(False) + return result + + def cache_info(): + """Report cache statistics""" + with lock: + return hits[0], misses[0], maxsize, len(cache) + + def cache_clear(): + """Clear the cache and cache statistics""" + with lock: + cache.clear() + hits[0] = misses[0] = 0 + + wrapper.cache_info = cache_info + wrapper.cache_clear = cache_clear + return wrapper + + return decorating_function + + +class _WeakArgs(Sequence): + """ + Works with _WeakArgsDict to provide a weak cache for function args. + When any of those args are gc'd, the pair is removed from the cache. + """ + def __init__(self, items, dict_remove=None): + def remove(k, selfref=ref(self), dict_remove=dict_remove): + self = selfref() + if self is not None and dict_remove is not None: + dict_remove(self) + + self._items, self._selectors = unzip(self._try_ref(item, remove) + for item in items) + self._items = tuple(self._items) + self._selectors = tuple(self._selectors) + + def __getitem__(self, index): + return self._items[index] + + def __len__(self): + return len(self._items) + + @staticmethod + def _try_ref(item, callback): + try: + return ref(item, callback), True + except TypeError: + return item, False + + @property + def alive(self): + return all(item() is not None + for item in compress(self._items, self._selectors)) + + def __eq__(self, other): + return self._items == other._items + + def __hash__(self): + try: + return self.__hash + except AttributeError: + h = self.__hash = hash(self._items) + return h + + +class _WeakArgsDict(WeakKeyDictionary, object): + def __delitem__(self, key): + del self.data[_WeakArgs(key)] + + def __getitem__(self, key): + return self.data[_WeakArgs(key)] + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self.data) + + def __setitem__(self, key, value): + self.data[_WeakArgs(key, self._remove)] = value + + def __contains__(self, key): + try: + wr = _WeakArgs(key) + except TypeError: + return False + return wr in self.data + + def pop(self, key, *args): + return self.data.pop(_WeakArgs(key), *args) + + +class _WeakArgsOrderedDict(_WeakArgsDict, object): + def __init__(self): + super(_WeakArgsOrderedDict, self).__init__() + self.data = OrderedDict() + + def popitem(self, last=True): + while True: + key, value = self.data.popitem(last) + if key.alive: + return tuple(key), value + + def move_to_end(self, key): + """Move an existing element to the end. + + Raises KeyError if the element does not exist. + """ + self[key] = self.pop(key) + + +def weak_lru_cache(maxsize=100): + """Weak least-recently-used cache decorator. + + If *maxsize* is set to None, the LRU features are disabled and the cache + can grow without bound. + + Arguments to the cached function must be hashable. Any that are weak- + referenceable will be stored by weak reference. Once any of the args have + been garbage collected, the entry will be removed from the cache. + + View the cache statistics named tuple (hits, misses, maxsize, currsize) + with f.cache_info(). Clear the cache and statistics with f.cache_clear(). + + See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used + + """ + class desc(lazyval): + def __get__(self, instance, owner): + if instance is None: + return self + try: + return self._cache[instance] + except KeyError: + inst = ref(instance) + + @_weak_lru_cache(maxsize) + @wraps(self._get) + def wrapper(*args, **kwargs): + return self._get(inst(), *args, **kwargs) + + self._cache[instance] = wrapper + return wrapper + + @_weak_lru_cache(maxsize) + def __call__(self, *args, **kwargs): + return self._get(*args, **kwargs) + + return desc + + +remember_last = weak_lru_cache(1)