diff --git a/docs/source/whatsnew/0.9.1.txt b/docs/source/whatsnew/0.9.1.txt index 1f5791d7..b115b5ee 100644 --- a/docs/source/whatsnew/0.9.1.txt +++ b/docs/source/whatsnew/0.9.1.txt @@ -22,6 +22,11 @@ Enhancements factor to only compute over stocks for which the filter returns True, rather than always computing over the entire universe of stocks. (:issue:`1095`) +* Added :class:`zipline.utils.cache.ExpiringCache`. + A cache which wraps entries in a :class:`zipline.utils.cache.CachedObject`, + which manages expiration of entries based on the `dt` supplied to the `get` + method. + Experimental Features ~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/utils/test_cache.py b/tests/utils/test_cache.py index c69a0206..19c4af88 100644 --- a/tests/utils/test_cache.py +++ b/tests/utils/test_cache.py @@ -2,7 +2,7 @@ from unittest import TestCase from pandas import Timestamp, Timedelta -from zipline.utils.cache import CachedObject, Expired +from zipline.utils.cache import CachedObject, Expired, ExpiringCache class CachedObjectTestCase(TestCase): @@ -19,3 +19,42 @@ class CachedObjectTestCase(TestCase): with self.assertRaises(Expired) as e: obj.unwrap(after) self.assertEqual(e.exception.args, (expiry,)) + + +class ExpiringCacheTestCase(TestCase): + + def test_expiring_cache(self): + expiry_1 = Timestamp('2014') + before_1 = expiry_1 - Timedelta('1 minute') + after_1 = expiry_1 + Timedelta('1 minute') + + expiry_2 = Timestamp('2015') + after_2 = expiry_1 + Timedelta('1 minute') + + expiry_3 = Timestamp('2016') + + cache = ExpiringCache() + + cache.set('foo', 1, expiry_1) + cache.set('bar', 2, expiry_2) + + self.assertEqual(cache.get('foo', before_1), 1) + # Unwrap on expiry is allowed. + self.assertEqual(cache.get('foo', expiry_1), 1) + + with self.assertRaises(KeyError) as e: + self.assertEqual(cache.get('foo', after_1)) + self.assertEqual(e.exception.args, ('foo',)) + + # Should raise same KeyError after deletion. + with self.assertRaises(KeyError) as e: + self.assertEqual(cache.get('foo', before_1)) + self.assertEqual(e.exception.args, ('foo',)) + + # Second value should still exist. + self.assertEqual(cache.get('bar', after_2), 2) + + # Should raise similar KeyError on non-existent key. + with self.assertRaises(KeyError) as e: + self.assertEqual(cache.get('baz', expiry_3)) + self.assertEqual(e.exception.args, ('baz',)) diff --git a/zipline/utils/cache.py b/zipline/utils/cache.py index 9ce6ade7..9f9f1e53 100644 --- a/zipline/utils/cache.py +++ b/zipline/utils/cache.py @@ -57,3 +57,58 @@ class CachedObject(namedtuple("_CachedObject", "value expires")): if dt > self.expires: raise Expired(self.expires) return self.value + + +class ExpiringCache(object): + """ + A cache of multiple CachedObjects, which returns the wrapped the value + or raises and deletes the CachedObject if the value has expired. + + Parameters + ---------- + cache : dict-like + An instance of a dict-like object which needs to support at least: + `__del__`, `__getitem__`, `__setitem__` + If `None`, than a dict is used as a default. + + Methods + ------- + get(self, key, dt) + Get the value of a cached object for the given `key` at `dt`, if the + CachedObject has expired then the object is removed from the cache, + and `KeyError` is raised. + + set(self, key, value, expiration_dt) + Add a new `value` to the cache at `dt` wrapped in a CachedObject which + expires at `expiration_dt`. + + Usage + ----- + >>> from pandas import Timestamp, Timedelta + >>> expires = Timestamp('2014', tz='UTC') + >>> value = 1 + >>> cache = ExpiringCache() + >>> cache.set('foo', value, expires) + >>> cache.get('foo', expires - Timedelta('1 minute')) + 1 + >>> cache.get('foo', expires + Timedelta('1 minute')) + Traceback (most recent call last): + ... + KeyError: 'foo' + """ + + def __init__(self, cache=None): + if cache is not None: + self._cache = cache + else: + self._cache = {} + + def get(self, key, dt): + try: + return self._cache[key].unwrap(dt) + except Expired: + del self._cache[key] + raise KeyError(key) + + def set(self, key, value, expiration_dt): + self._cache[key] = CachedObject(value, expiration_dt)