mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 19:14:36 +08:00
PERF: Reimplemented remember_last with a weak_lru_cache,
which won't leak instances whose methods have been decorated (specifically DataPortal instances). MAINT: No longer using functools32.
This commit is contained in:
@@ -33,9 +33,6 @@ cyordereddict==0.2.2
|
||||
# faster array ops.
|
||||
bottleneck==1.0.0
|
||||
|
||||
# lru_cache
|
||||
functools32==3.2.3.post2;python_version<'3.0'
|
||||
|
||||
contextlib2==0.4.0
|
||||
|
||||
# networkx requires decorator
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
Tests for zipline.utils.memoize.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
import gc
|
||||
from unittest import TestCase
|
||||
|
||||
from zipline.utils.memoize import remember_last
|
||||
@@ -32,3 +34,63 @@ class TestRememberLast(TestCase):
|
||||
# Calling the old value should still increment the counter.
|
||||
self.assertEqual((func(1), call_count[0]), (1, 3))
|
||||
self.assertEqual((func(1), call_count[0]), (1, 3))
|
||||
|
||||
def test_remember_last_method(self):
    """
    remember_last on a method remembers the last call per instance and
    does not keep the instances alive.
    """
    call_count = defaultdict(int)

    class clz(object):
        @remember_last
        def func(self, x):
            call_count[(self, x)] += 1
            return x

    inst1 = clz()

    # The first call runs func; repeating it with the same argument
    # re-uses the remembered value, so the count stays at one.
    for _ in range(2):
        self.assertEqual(inst1.func(1), 1)
        self.assertEqual(call_count, {(inst1, 1): 1})

    # A new argument runs func once, then is remembered.
    for _ in range(2):
        self.assertEqual(inst1.func(2), 2)
        self.assertEqual(call_count, {(inst1, 1): 1, (inst1, 2): 1})

    # Only the last call is remembered, so going back to the old
    # argument runs func again.
    for _ in range(2):
        self.assertEqual(inst1.func(1), 1)
        self.assertEqual(call_count, {(inst1, 1): 2, (inst1, 2): 1})

    # The same sequence on a second instance leaves the counts recorded
    # for the first instance untouched.
    inst2 = clz()
    for _ in range(2):
        self.assertEqual(inst2.func(1), 1)
        self.assertEqual(call_count,
                         {(inst1, 1): 2, (inst1, 2): 1,
                          (inst2, 1): 1})

    for _ in range(2):
        self.assertEqual(inst2.func(2), 2)
        self.assertEqual(call_count,
                         {(inst1, 1): 2, (inst1, 2): 1,
                          (inst2, 1): 1, (inst2, 2): 1})

    for _ in range(2):
        self.assertEqual(inst2.func(1), 1)
        self.assertEqual(call_count,
                         {(inst1, 1): 2, (inst1, 2): 1,
                          (inst2, 1): 2, (inst2, 2): 1})

    # Remove the above references to the instances and ensure that
    # remember_last has not made its own. No other local may hold an
    # instance at this point, or the check below would be meaningless.
    del inst1, inst2
    call_count.clear()
    while gc.collect():
        pass

    survivors = [obj for obj in gc.get_objects() if type(obj) == clz]
    self.assertFalse(survivors)
|
||||
|
||||
@@ -31,13 +31,12 @@ from zipline.data.us_equity_loader import (
|
||||
)
|
||||
|
||||
from zipline.utils import tradingcalendar
|
||||
from zipline.utils.compat import lru_cache
|
||||
from zipline.utils.math_utils import (
|
||||
nansum,
|
||||
nanmean,
|
||||
nanstd
|
||||
)
|
||||
from zipline.utils.memoize import remember_last
|
||||
from zipline.utils.memoize import remember_last, weak_lru_cache
|
||||
from zipline.errors import (
|
||||
NoTradeDataAvailableTooEarly,
|
||||
NoTradeDataAvailableTooLate,
|
||||
@@ -1686,7 +1685,7 @@ class DataPortal(object):
|
||||
else:
|
||||
return [assets] if isinstance(assets, Asset) else []
|
||||
|
||||
@lru_cache(20)
|
||||
@weak_lru_cache(20)
|
||||
def _get_minute_count_for_transform(self, ending_minute, days_count):
|
||||
# cache size picked somewhat loosely. this code exists purely to
|
||||
# handle deprecated API.
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from six import PY2
|
||||
|
||||
|
||||
if PY2:
|
||||
from functools32 import lru_cache # noqa
|
||||
else:
|
||||
from functools import lru_cache # noqa
|
||||
+215
-3
@@ -1,8 +1,13 @@
|
||||
"""
|
||||
Tools for memoization of function results.
|
||||
"""
|
||||
from zipline.utils.compat import lru_cache
|
||||
from weakref import WeakKeyDictionary
|
||||
from collections import OrderedDict, Sequence
|
||||
from functools import wraps
|
||||
from itertools import compress
|
||||
from weakref import WeakKeyDictionary, ref
|
||||
|
||||
from six.moves._thread import allocate_lock as Lock
|
||||
from toolz.sandbox import unzip
|
||||
|
||||
|
||||
class lazyval(object):
|
||||
@@ -84,4 +89,211 @@ class classlazyval(lazyval):
|
||||
return super(classlazyval, self).__get__(owner, owner)
|
||||
|
||||
|
||||
remember_last = lru_cache(1)
|
||||
def _weak_lru_cache(maxsize=100):
    """
    Build the caching decorator used by ``weak_lru_cache``.

    With ``maxsize`` set to None results are kept in a ``_WeakArgsDict``
    with no ordering and no size limit. Otherwise they are kept in a
    ``_WeakArgsOrderedDict`` and the least recently used entry is purged
    whenever the cache grows past ``maxsize``.

    The decorated function gains ``cache_info()`` and ``cache_clear()``.
    Callers should use only those two and treat everything else as
    private, so the implementation stays free to change.
    """
    def decorating_function(
            user_function, tuple=tuple, sorted=sorted, len=len,
            KeyError=KeyError):

        # One-element lists serve as counters the closures below can
        # update in place.
        hits = [0]
        misses = [0]
        kwd_mark = (object(),)  # separates positional and keyword args
        lock = Lock()  # needed because OrderedDict isn't threadsafe

        def make_key(args, kwds):
            # Positional args alone form the key; keyword args are
            # appended in sorted order behind the unique marker.
            if not kwds:
                return args
            return args + kwd_mark + tuple(sorted(kwds.items()))

        if maxsize is not None:
            # ordered least recent to most recent
            cache = _WeakArgsOrderedDict()
            evict = cache.popitem
            touch = cache.move_to_end

            @wraps(user_function)
            def wrapper(*args, **kwds):
                key = make_key(args, kwds)
                with lock:
                    try:
                        cached = cache[key]
                        touch(key)  # record recent use of this key
                        hits[0] += 1
                        return cached
                    except KeyError:
                        pass
                # The user function runs outside the lock.
                computed = user_function(*args, **kwds)
                with lock:
                    cache[key] = computed  # record recent use of this key
                    misses[0] += 1
                    if len(cache) > maxsize:
                        # purge least recently used cache entry
                        evict(False)
                return computed
        else:
            cache = _WeakArgsDict()  # cache without ordering or size limit

            @wraps(user_function)
            def wrapper(*args, **kwds):
                key = make_key(args, kwds)
                try:
                    cached = cache[key]
                    hits[0] += 1
                    return cached
                except KeyError:
                    pass
                computed = user_function(*args, **kwds)
                cache[key] = computed
                misses[0] += 1
                return computed

        def cache_info():
            """Return the tuple (hits, misses, maxsize, currsize)."""
            with lock:
                return hits[0], misses[0], maxsize, len(cache)

        def cache_clear():
            """Empty the cache and reset the hit and miss counters."""
            with lock:
                cache.clear()
                hits[0] = 0
                misses[0] = 0

        wrapper.cache_info = cache_info
        wrapper.cache_clear = cache_clear
        return wrapper

    return decorating_function
|
||||
|
||||
|
||||
class _WeakArgs(Sequence):
|
||||
"""
|
||||
Works with _WeakArgsDict to provide a weak cache for function args.
|
||||
When any of those args are gc'd, the pair is removed from the cache.
|
||||
"""
|
||||
def __init__(self, items, dict_remove=None):
|
||||
def remove(k, selfref=ref(self), dict_remove=dict_remove):
|
||||
self = selfref()
|
||||
if self is not None and dict_remove is not None:
|
||||
dict_remove(self)
|
||||
|
||||
self._items, self._selectors = unzip(self._try_ref(item, remove)
|
||||
for item in items)
|
||||
self._items = tuple(self._items)
|
||||
self._selectors = tuple(self._selectors)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._items[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._items)
|
||||
|
||||
@staticmethod
|
||||
def _try_ref(item, callback):
|
||||
try:
|
||||
return ref(item, callback), True
|
||||
except TypeError:
|
||||
return item, False
|
||||
|
||||
@property
|
||||
def alive(self):
|
||||
return all(item() is not None
|
||||
for item in compress(self._items, self._selectors))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._items == other._items
|
||||
|
||||
def __hash__(self):
|
||||
try:
|
||||
return self.__hash
|
||||
except AttributeError:
|
||||
h = self.__hash = hash(self._items)
|
||||
return h
|
||||
|
||||
|
||||
class _WeakArgsDict(WeakKeyDictionary, object):
    """
    Mapping whose keys are argument tuples wrapped in ``_WeakArgs``.

    The overridden methods wrap the plain key they are given before
    touching the underlying ``self.data`` storage. Keys stored through
    item assignment are given ``self._remove`` as their removal callback.
    """
    def __delitem__(self, key):
        wrapped = _WeakArgs(key)
        del self.data[wrapped]

    def __getitem__(self, key):
        wrapped = _WeakArgs(key)
        return self.data[wrapped]

    def __repr__(self):
        cls_name = type(self).__name__
        return '%s(%r)' % (cls_name, self.data)

    def __setitem__(self, key, value):
        # Only stored keys carry the removal callback; lookup keys are
        # throwaway wrappers.
        wrapped = _WeakArgs(key, self._remove)
        self.data[wrapped] = value

    def __contains__(self, key):
        # A key that cannot be wrapped cannot be present.
        try:
            wrapped = _WeakArgs(key)
        except TypeError:
            return False
        else:
            return wrapped in self.data

    def pop(self, key, *args):
        wrapped = _WeakArgs(key)
        return self.data.pop(wrapped, *args)
|
||||
|
||||
|
||||
class _WeakArgsOrderedDict(_WeakArgsDict, object):
    """
    A ``_WeakArgsDict`` whose storage is an ``OrderedDict``, so entries
    keep their insertion order.
    """
    def __init__(self):
        super(_WeakArgsOrderedDict, self).__init__()
        # Swap the storage set up by the base class for an ordered one.
        self.data = OrderedDict()

    def popitem(self, last=True):
        """Remove and return a (key, value) pair whose key is still alive.

        Entries whose key is no longer alive are dropped along the way.
        The key is returned as a tuple of the stored key items. The
        KeyError raised by the underlying ``OrderedDict`` once it is empty
        is not caught.
        """
        entry = self.data.popitem(last)
        while not entry[0].alive:
            entry = self.data.popitem(last)
        weak_key, value = entry
        return tuple(weak_key), value

    def move_to_end(self, key):
        """Move an existing element to the end.

        Raises KeyError if the element does not exist.
        """
        # Popping and re-inserting places the entry last in the order.
        value = self.pop(key)
        self[key] = value
|
||||
|
||||
|
||||
def weak_lru_cache(maxsize=100):
    """Weak least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    Arguments to the cached function must be hashable. Any that are weak-
    referenceable will be stored by weak reference. Once any of the args have
    been garbage collected, the entry will be removed from the cache.

    The cached callables expose cache_info(), which returns the tuple
    (hits, misses, maxsize, currsize), and cache_clear(), which empties the
    cache and resets the statistics.

    The decorator is returned as a descriptor class. When the decorated
    function is reached through an instance, that instance gets its own
    cached wrapper, which refers to the instance only weakly.

    See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used

    """
    # NOTE(review): ``self._get`` and ``self._cache`` come from ``lazyval``,
    # which is not shown here -- presumably the wrapped function and a
    # per-instance mapping; confirm against that class.
    class desc(lazyval):
        def __get__(self, instance, owner):
            # Access through the class yields the descriptor itself.
            if instance is None:
                return self
            try:
                return self._cache[instance]
            except KeyError:
                # Refer to the instance weakly so the wrapper built for it
                # does not hold a strong reference to it.
                weak_instance = ref(instance)

                def call_on_instance(*args, **kwargs):
                    return self._get(weak_instance(), *args, **kwargs)

                cached_call = _weak_lru_cache(maxsize)(
                    wraps(self._get)(call_on_instance)
                )
                self._cache[instance] = cached_call
                return cached_call

        @_weak_lru_cache(maxsize)
        def __call__(self, *args, **kwargs):
            # Calling the descriptor directly uses one cache of its own.
            return self._get(*args, **kwargs)

    return desc
|
||||
|
||||
|
||||
# Decorator that remembers only the most recent call: a weak_lru_cache
# with room for a single entry.
remember_last = weak_lru_cache(1)
|
||||
|
||||
Reference in New Issue
Block a user