PERF: Reimplemented remember_last with a weak_lru_cache

which won't leak instances whose methods have been decorated

(specifically DataPortal instances)

MAINT: Not using functools32 anymore
This commit is contained in:
Richard Frank
2016-04-07 16:35:09 -04:00
parent 32a400a9fb
commit 5254b273b2
5 changed files with 279 additions and 16 deletions
-3
View File
@@ -33,9 +33,6 @@ cyordereddict==0.2.2
# faster array ops.
bottleneck==1.0.0
# lru_cache
functools32==3.2.3.post2;python_version<'3.0'
contextlib2==0.4.0
# networkx requires decorator
+62
View File
@@ -1,6 +1,8 @@
"""
Tests for zipline.utils.memoize.
"""
from collections import defaultdict
import gc
from unittest import TestCase
from zipline.utils.memoize import remember_last
@@ -32,3 +34,63 @@ class TestRememberLast(TestCase):
# Calling the old value should still increment the counter.
self.assertEqual((func(1), call_count[0]), (1, 3))
self.assertEqual((func(1), call_count[0]), (1, 3))
def test_remember_last_method(self):
call_count = defaultdict(int)
class clz(object):
@remember_last
def func(self, x):
call_count[(self, x)] += 1
return x
inst1 = clz()
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1}))
# Calling again with the same argument should just re-use the old
# value, which means func shouldn't get called again.
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 1}))
# Calling with a new value should increment the counter.
self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1,
(inst1, 2): 1}))
self.assertEqual((inst1.func(2), call_count), (2, {(inst1, 1): 1,
(inst1, 2): 1}))
# Calling the old value should still increment the counter.
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2,
(inst1, 2): 1}))
self.assertEqual((inst1.func(1), call_count), (1, {(inst1, 1): 2,
(inst1, 2): 1}))
inst2 = clz()
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1}))
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1}))
self.assertEqual((inst2.func(2), call_count),
(2, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1, (inst2, 2): 1}))
self.assertEqual((inst2.func(2), call_count),
(2, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 1, (inst2, 2): 1}))
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 2, (inst2, 2): 1}))
self.assertEqual((inst2.func(1), call_count),
(1, {(inst1, 1): 2, (inst1, 2): 1,
(inst2, 1): 2, (inst2, 2): 1}))
# Remove the above references to the instances and ensure that
# remember_last has not made its own.
del inst1, inst2
call_count.clear()
while gc.collect():
pass
self.assertFalse([inst for inst in gc.get_objects()
if type(inst) == clz])
+2 -3
View File
@@ -31,13 +31,12 @@ from zipline.data.us_equity_loader import (
)
from zipline.utils import tradingcalendar
from zipline.utils.compat import lru_cache
from zipline.utils.math_utils import (
nansum,
nanmean,
nanstd
)
from zipline.utils.memoize import remember_last
from zipline.utils.memoize import remember_last, weak_lru_cache
from zipline.errors import (
NoTradeDataAvailableTooEarly,
NoTradeDataAvailableTooLate,
@@ -1686,7 +1685,7 @@ class DataPortal(object):
else:
return [assets] if isinstance(assets, Asset) else []
@lru_cache(20)
@weak_lru_cache(20)
def _get_minute_count_for_transform(self, ending_minute, days_count):
# cache size picked somewhat loosely. this code exists purely to
# handle deprecated API.
-7
View File
@@ -1,7 +0,0 @@
from six import PY2
if PY2:
from functools32 import lru_cache # noqa
else:
from functools import lru_cache # noqa
+215 -3
View File
@@ -1,8 +1,13 @@
"""
Tools for memoization of function results.
"""
from zipline.utils.compat import lru_cache
from weakref import WeakKeyDictionary
from collections import OrderedDict, Sequence
from functools import wraps
from itertools import compress
from weakref import WeakKeyDictionary, ref
from six.moves._thread import allocate_lock as Lock
from toolz.sandbox import unzip
class lazyval(object):
@@ -84,4 +89,211 @@ class classlazyval(lazyval):
return super(classlazyval, self).__get__(owner, owner)
remember_last = lru_cache(1)
def _weak_lru_cache(maxsize=100):
"""
Users should only access the lru_cache through its public API:
cache_info, cache_clear
The internals of the lru_cache are encapsulated for thread safety and
to allow the implementation to change.
"""
def decorating_function(
user_function, tuple=tuple, sorted=sorted, len=len,
KeyError=KeyError):
hits, misses = [0], [0]
kwd_mark = (object(),) # separates positional and keyword args
lock = Lock() # needed because OrderedDict isn't threadsafe
if maxsize is None:
cache = _WeakArgsDict() # cache without ordering or size limit
@wraps(user_function)
def wrapper(*args, **kwds):
key = args
if kwds:
key += kwd_mark + tuple(sorted(kwds.items()))
try:
result = cache[key]
hits[0] += 1
return result
except KeyError:
pass
result = user_function(*args, **kwds)
cache[key] = result
misses[0] += 1
return result
else:
# ordered least recent to most recent
cache = _WeakArgsOrderedDict()
cache_popitem = cache.popitem
cache_renew = cache.move_to_end
@wraps(user_function)
def wrapper(*args, **kwds):
key = args
if kwds:
key += kwd_mark + tuple(sorted(kwds.items()))
with lock:
try:
result = cache[key]
cache_renew(key) # record recent use of this key
hits[0] += 1
return result
except KeyError:
pass
result = user_function(*args, **kwds)
with lock:
cache[key] = result # record recent use of this key
misses[0] += 1
if len(cache) > maxsize:
# purge least recently used cache entry
cache_popitem(False)
return result
def cache_info():
"""Report cache statistics"""
with lock:
return hits[0], misses[0], maxsize, len(cache)
def cache_clear():
"""Clear the cache and cache statistics"""
with lock:
cache.clear()
hits[0] = misses[0] = 0
wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
return wrapper
return decorating_function
class _WeakArgs(Sequence):
"""
Works with _WeakArgsDict to provide a weak cache for function args.
When any of those args are gc'd, the pair is removed from the cache.
"""
def __init__(self, items, dict_remove=None):
def remove(k, selfref=ref(self), dict_remove=dict_remove):
self = selfref()
if self is not None and dict_remove is not None:
dict_remove(self)
self._items, self._selectors = unzip(self._try_ref(item, remove)
for item in items)
self._items = tuple(self._items)
self._selectors = tuple(self._selectors)
def __getitem__(self, index):
return self._items[index]
def __len__(self):
return len(self._items)
@staticmethod
def _try_ref(item, callback):
try:
return ref(item, callback), True
except TypeError:
return item, False
@property
def alive(self):
return all(item() is not None
for item in compress(self._items, self._selectors))
def __eq__(self, other):
return self._items == other._items
def __hash__(self):
try:
return self.__hash
except AttributeError:
h = self.__hash = hash(self._items)
return h
class _WeakArgsDict(WeakKeyDictionary, object):
def __delitem__(self, key):
del self.data[_WeakArgs(key)]
def __getitem__(self, key):
return self.data[_WeakArgs(key)]
def __repr__(self):
return '%s(%r)' % (type(self).__name__, self.data)
def __setitem__(self, key, value):
self.data[_WeakArgs(key, self._remove)] = value
def __contains__(self, key):
try:
wr = _WeakArgs(key)
except TypeError:
return False
return wr in self.data
def pop(self, key, *args):
return self.data.pop(_WeakArgs(key), *args)
class _WeakArgsOrderedDict(_WeakArgsDict, object):
def __init__(self):
super(_WeakArgsOrderedDict, self).__init__()
self.data = OrderedDict()
def popitem(self, last=True):
while True:
key, value = self.data.popitem(last)
if key.alive:
return tuple(key), value
def move_to_end(self, key):
"""Move an existing element to the end.
Raises KeyError if the element does not exist.
"""
self[key] = self.pop(key)
def weak_lru_cache(maxsize=100):
"""Weak least-recently-used cache decorator.
If *maxsize* is set to None, the LRU features are disabled and the cache
can grow without bound.
Arguments to the cached function must be hashable. Any that are weak-
referenceable will be stored by weak reference. Once any of the args have
been garbage collected, the entry will be removed from the cache.
View the cache statistics named tuple (hits, misses, maxsize, currsize)
with f.cache_info(). Clear the cache and statistics with f.cache_clear().
See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used
"""
class desc(lazyval):
def __get__(self, instance, owner):
if instance is None:
return self
try:
return self._cache[instance]
except KeyError:
inst = ref(instance)
@_weak_lru_cache(maxsize)
@wraps(self._get)
def wrapper(*args, **kwargs):
return self._get(inst(), *args, **kwargs)
self._cache[instance] = wrapper
return wrapper
@_weak_lru_cache(maxsize)
def __call__(self, *args, **kwargs):
return self._get(*args, **kwargs)
return desc
remember_last = weak_lru_cache(1)