diff --git a/etc/requirements.txt b/etc/requirements.txt index a90571c8..7ffa9527 100644 --- a/etc/requirements.txt +++ b/etc/requirements.txt @@ -33,9 +33,6 @@ cyordereddict==0.2.2 # faster array ops. bottleneck==1.0.0 -# lru_cache -functools32==3.2.3.post2;python_version<'3.0' - contextlib2==0.4.0 # networkx requires decorator diff --git a/tests/test_memoize.py b/tests/test_memoize.py index 75996e05..c1621249 100644 --- a/tests/test_memoize.py +++ b/tests/test_memoize.py @@ -1,6 +1,8 @@ """ Tests for zipline.utils.memoize. """ +from collections import defaultdict +import gc from unittest import TestCase from zipline.utils.memoize import remember_last @@ -32,3 +34,63 @@ class TestRememberLast(TestCase): # Calling the old value should still increment the counter. self.assertEqual((func(1), call_count[0]), (1, 3)) self.assertEqual((func(1), call_count[0]), (1, 3)) + + def test_remember_last_method(self): + call_count = defaultdict(int) + + class clz(object): + @remember_last + def func(self, x): + call_count[(self, x)] += 1 + return x + + inst1 = clz() + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1})) + + # Calling again with the same argument should just re-use the old + # value, which means func shouldn't get called again. + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1})) + + # Calling with a new value should increment the counter. + self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1, + (inst1, 2): 1})) + self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1, + (inst1, 2): 1})) + + # Calling the old value should still increment the counter. + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2, + (inst1, 2): 1})) + self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2, + (inst1, 2): 1})) + + inst2 = clz() + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1})) + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1})) + + self.assertEqual((inst2.func(2), call_count), + (2, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1, (inst2, 2): 1})) + self.assertEqual((inst2.func(2), call_count), + (2, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 1, (inst2, 2): 1})) + + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 2, (inst2, 2): 1})) + self.assertEqual((inst2.func(1), call_count), + (1, {(inst1, 1): 2, (inst1, 2): 1, + (inst2, 1): 2, (inst2, 2): 1})) + + # Remove the above references to the instances and ensure that + # remember_last has not made its own. + del inst1, inst2 + call_count.clear() + while gc.collect(): + pass + + self.assertFalse([inst for inst in gc.get_objects() + if type(inst) == clz]) diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index 9db60e5c..84c53c05 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -31,13 +31,12 @@ from zipline.data.us_equity_loader import ( ) from zipline.utils import tradingcalendar -from zipline.utils.compat import lru_cache from zipline.utils.math_utils import ( nansum, nanmean, nanstd ) -from zipline.utils.memoize import remember_last +from zipline.utils.memoize import remember_last, weak_lru_cache from zipline.errors import ( NoTradeDataAvailableTooEarly, NoTradeDataAvailableTooLate, @@ -1686,7 +1685,7 @@ class DataPortal(object): else: return [assets] if isinstance(assets, Asset) else [] - @lru_cache(20) + @weak_lru_cache(20) def _get_minute_count_for_transform(self, ending_minute, days_count): # cache size picked somewhat loosely. this code exists purely to # handle deprecated API. diff --git a/zipline/utils/compat.py b/zipline/utils/compat.py deleted file mode 100644 index d0a037c9..00000000 --- a/zipline/utils/compat.py +++ /dev/null @@ -1,7 +0,0 @@ -from six import PY2 - - -if PY2: - from functools32 import lru_cache # noqa -else: - from functools import lru_cache # noqa diff --git a/zipline/utils/memoize.py b/zipline/utils/memoize.py index 1b6a4bf1..f0c77f3d 100644 --- a/zipline/utils/memoize.py +++ b/zipline/utils/memoize.py @@ -1,8 +1,13 @@ """ Tools for memoization of function results. """ -from zipline.utils.compat import lru_cache -from weakref import WeakKeyDictionary +from collections import OrderedDict, Sequence +from functools import wraps +from itertools import compress +from weakref import WeakKeyDictionary, ref + +from six.moves._thread import allocate_lock as Lock +from toolz.sandbox import unzip class lazyval(object): @@ -84,4 +89,211 @@ class classlazyval(lazyval): return super(classlazyval, self).__get__(owner, owner) -remember_last = lru_cache(1) +def _weak_lru_cache(maxsize=100): + """ + Users should only access the lru_cache through its public API: + cache_info, cache_clear + The internals of the lru_cache are encapsulated for thread safety and + to allow the implementation to change. + """ + def decorating_function( + user_function, tuple=tuple, sorted=sorted, len=len, + KeyError=KeyError): + + hits, misses = [0], [0] + kwd_mark = (object(),) # separates positional and keyword args + lock = Lock() # needed because OrderedDict isn't threadsafe + + if maxsize is None: + cache = _WeakArgsDict() # cache without ordering or size limit + + @wraps(user_function) + def wrapper(*args, **kwds): + key = args + if kwds: + key += kwd_mark + tuple(sorted(kwds.items())) + try: + result = cache[key] + hits[0] += 1 + return result + except KeyError: + pass + result = user_function(*args, **kwds) + cache[key] = result + misses[0] += 1 + return result + else: + # ordered least recent to most recent + cache = _WeakArgsOrderedDict() + cache_popitem = cache.popitem + cache_renew = cache.move_to_end + + @wraps(user_function) + def wrapper(*args, **kwds): + key = args + if kwds: + key += kwd_mark + tuple(sorted(kwds.items())) + with lock: + try: + result = cache[key] + cache_renew(key) # record recent use of this key + hits[0] += 1 + return result + except KeyError: + pass + result = user_function(*args, **kwds) + with lock: + cache[key] = result # record recent use of this key + misses[0] += 1 + if len(cache) > maxsize: + # purge least recently used cache entry + cache_popitem(False) + return result + + def cache_info(): + """Report cache statistics""" + with lock: + return hits[0], misses[0], maxsize, len(cache) + + def cache_clear(): + """Clear the cache and cache statistics""" + with lock: + cache.clear() + hits[0] = misses[0] = 0 + + wrapper.cache_info = cache_info + wrapper.cache_clear = cache_clear + return wrapper + + return decorating_function + + +class _WeakArgs(Sequence): + """ + Works with _WeakArgsDict to provide a weak cache for function args. + When any of those args are gc'd, the pair is removed from the cache. + """ + def __init__(self, items, dict_remove=None): + def remove(k, selfref=ref(self), dict_remove=dict_remove): + self = selfref() + if self is not None and dict_remove is not None: + dict_remove(self) + + self._items, self._selectors = unzip(self._try_ref(item, remove) + for item in items) + self._items = tuple(self._items) + self._selectors = tuple(self._selectors) + + def __getitem__(self, index): + return self._items[index] + + def __len__(self): + return len(self._items) + + @staticmethod + def _try_ref(item, callback): + try: + return ref(item, callback), True + except TypeError: + return item, False + + @property + def alive(self): + return all(item() is not None + for item in compress(self._items, self._selectors)) + + def __eq__(self, other): + return self._items == other._items + + def __hash__(self): + try: + return self.__hash + except AttributeError: + h = self.__hash = hash(self._items) + return h + + +class _WeakArgsDict(WeakKeyDictionary, object): + def __delitem__(self, key): + del self.data[_WeakArgs(key)] + + def __getitem__(self, key): + return self.data[_WeakArgs(key)] + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self.data) + + def __setitem__(self, key, value): + self.data[_WeakArgs(key, self._remove)] = value + + def __contains__(self, key): + try: + wr = _WeakArgs(key) + except TypeError: + return False + return wr in self.data + + def pop(self, key, *args): + return self.data.pop(_WeakArgs(key), *args) + + +class _WeakArgsOrderedDict(_WeakArgsDict, object): + def __init__(self): + super(_WeakArgsOrderedDict, self).__init__() + self.data = OrderedDict() + + def popitem(self, last=True): + while True: + key, value = self.data.popitem(last) + if key.alive: + return tuple(key), value + + def move_to_end(self, key): + """Move an existing element to the end. + + Raises KeyError if the element does not exist. + """ + self[key] = self.pop(key) + + +def weak_lru_cache(maxsize=100): + """Weak least-recently-used cache decorator. + + If *maxsize* is set to None, the LRU features are disabled and the cache + can grow without bound. + + Arguments to the cached function must be hashable. Any that are weak- + referenceable will be stored by weak reference. Once any of the args have + been garbage collected, the entry will be removed from the cache. + + View the cache statistics named tuple (hits, misses, maxsize, currsize) + with f.cache_info(). Clear the cache and statistics with f.cache_clear(). + + See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used + + """ + class desc(lazyval): + def __get__(self, instance, owner): + if instance is None: + return self + try: + return self._cache[instance] + except KeyError: + inst = ref(instance) + + @_weak_lru_cache(maxsize) + @wraps(self._get) + def wrapper(*args, **kwargs): + return self._get(inst(), *args, **kwargs) + + self._cache[instance] = wrapper + return wrapper + + @_weak_lru_cache(maxsize) + def __call__(self, *args, **kwargs): + return self._get(*args, **kwargs) + + return desc + + +remember_last = weak_lru_cache(1)