From ede1eb7aa021a3587da7b2a5bd2ba14daa763e58 Mon Sep 17 00:00:00 2001 From: Richard Frank Date: Wed, 3 Feb 2016 16:09:14 -0500 Subject: [PATCH] PERF: Look up expired futures from in-memory Futures instead of queries to the db. --- zipline/assets/_assets.pyx | 8 ++++ zipline/assets/assets.py | 24 ---------- zipline/gens/tradesimulation.py | 30 ++++++++----- zipline/protocol.py | 2 +- zipline/utils/data.py | 80 ++++++++++++++++++++++++++++++++- 5 files changed, 108 insertions(+), 36 deletions(-) diff --git a/zipline/assets/_assets.pyx b/zipline/assets/_assets.pyx index 371253f8..e3cac229 100644 --- a/zipline/assets/_assets.pyx +++ b/zipline/assets/_assets.pyx @@ -230,6 +230,7 @@ cdef class Future(Asset): cdef readonly object auto_close_date cdef readonly object tick_size cdef readonly float multiplier + cdef readonly object effective_expiration def __cinit__(self, int sid, # sid is required @@ -253,6 +254,13 @@ cdef class Future(Asset): self.tick_size = tick_size self.multiplier = multiplier + if notice_date is None: + self.effective_expiration = expiration_date + elif expiration_date is None: + self.effective_expiration = notice_date + else: + self.effective_expiration = min(notice_date, expiration_date) + def __str__(self): if self.symbol: return 'Future(%d [%s])' % (self.sid, self.symbol) diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 3124cab0..1af5b571 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -666,30 +666,6 @@ class AssetFinder(object): contracts = self.retrieve_futures_contracts(sids) return [contracts[sid] for sid in sids] - def lookup_expired_futures(self, start, end): - if not isinstance(start, pd.Timestamp): - start = pd.Timestamp(start) - start = start.value - if not isinstance(end, pd.Timestamp): - end = pd.Timestamp(end) - end = end.value - - fc_cols = self.futures_contracts.c - - nd = sa.func.nullif(fc_cols.notice_date, pd.tslib.iNaT) - ed = sa.func.nullif(fc_cols.expiration_date, pd.tslib.iNaT) - date = sa.func.coalesce(sa.func.min(nd, ed), ed, nd) - - sids = list(map( - itemgetter('sid'), - sa.select((fc_cols.sid,)).where( - (date >= start) & (date < end)).order_by( - sa.func.coalesce(ed, nd).asc() - ).execute().fetchall() - )) - - return sids - @property def sids(self): return tuple(map( diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index f720333b..3a8459c1 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -12,20 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from datetime import timedelta +from itertools import takewhile from contextlib2 import ExitStack - from logbook import Logger, Processor from pandas.tslib import normalize_date -from zipline.utils.api_support import ZiplineAPI - from zipline.finance.trading import NoFurtherDataError from zipline.protocol import ( BarData, SIDData, DATASOURCE_TYPE ) +from zipline.utils.api_support import ZiplineAPI +from zipline.utils.data import SortedDict log = Logger('Trade Simulation') @@ -56,15 +57,23 @@ class AlgorithmSimulator(object): # Snapshot Setup # ============== + def _get_effective_expiration(sid, + finder=self.env.asset_finder, + default=self.sim_params.last_close + + timedelta(days=1)): + asset = finder.retrieve_asset(sid) + return getattr(asset, 'effective_expiration', None) or default + + self._get_expiration = _get_effective_expiration + # The algorithm's data as of our most recent event. # We want an object that will have empty objects as default # values on missing keys. - self.current_data = BarData() + self.current_data = BarData(SortedDict(self._get_expiration)) # We don't have a datetime for the current snapshot until we # receive a message. self.simulation_dt = None - self.previous_dt = self.algo_start # ============= # Logging Setup @@ -97,14 +106,15 @@ class AlgorithmSimulator(object): self._call_before_trading_start(mkt_open) for date, snapshot in stream_in: - expired_sids = self.env.asset_finder.lookup_expired_futures( - start=self.previous_dt, end=date) - self.previous_dt = date + self.simulation_dt = date self.on_dt_changed(date) - # removing expired futures - for sid in expired_sids: + expired = list(takewhile( + lambda asset_id: self._get_expiration(asset_id) < date, + self.current_data + )) + for sid in expired: try: del self.current_data[sid] except KeyError: diff --git a/zipline/protocol.py b/zipline/protocol.py index cf4aca63..0355c54a 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -502,7 +502,7 @@ class BarData(object): """ def __init__(self, data=None): - self._data = data or {} + self._data = data if data is not None else {} self._contains_override = None def __contains__(self, name): diff --git a/zipline/utils/data.py b/zipline/utils/data.py index d81fd87d..1f4788b5 100644 --- a/zipline/utils/data.py +++ b/zipline/utils/data.py @@ -13,11 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import bisect import datetime +from collections import MutableMapping +from copy import deepcopy + +try: + from six.moves._thread import get_ident +except ImportError: + from six.moves._dummy_thread import get_ident import numpy as np import pandas as pd -from copy import deepcopy +from toolz import merge def _ensure_index(x): @@ -391,3 +399,73 @@ class MutableIndexRollingPanel(object): self.buffer.loc[non_nan_items, :, non_nan_cols]) self.buffer = new_buffer + + +class SortedDict(MutableMapping): + """A mapping of key-value pairs sorted by key according to the sort_key + function provided to the mapping. Ties from the sort_key are broken by + comparing the original keys. `iter` traverses the keys in sort order. + + Parameters + ---------- + key : callable + Called on keys in the mapping to produce the values by which those keys + are sorted. + mapping : mapping, optional + **kwargs + The initial mapping. + + >>> d = SortedDict(abs) + >>> d[-1] = 'negative one' + >>> d[0] = 'zero' + >>> d[2] = 'two' + >>> d + SortedDict[]([(0, 'zero'), (-1, 'negative one'), (1, 'one')]) # noqa + >>> d[1] = 'one' # Mutating the mapping maintains the sort order. + >>> d + SortedDict[]([(0, 'zero'), (-1, 'negative one'), (1, 'one')]) # noqa + """ + def __init__(self, sort_key, mapping=None, **kwargs): + self._map = {} + self._sorted_key_names = [] + self._sort_key = sort_key + + self.update(merge(mapping or {}, kwargs)) + + def __getitem__(self, name): + return self._map[name] + + def __setitem__(self, name, value, _bisect_right=bisect.bisect_right): + self._map[name] = value + if len(self._map) > len(self._sorted_key_names): + key = self._sort_key(name) + pair = (key, name) + idx = _bisect_right(self._sorted_key_names, pair) + self._sorted_key_names.insert(idx, pair) + + def __delitem__(self, name, _bisect_left=bisect.bisect_left): + del self._map[name] + idx = _bisect_left(self._sorted_key_names, + (self._sort_key(name), name)) + del self._sorted_key_names[idx] + + def __iter__(self): + for key, name in self._sorted_key_names: + yield name + + def __len__(self): + return len(self._map) + + def __repr__(self, _repr_running={}): + # Based on OrderedDict/defaultdict + call_key = id(self), get_ident() + if call_key in _repr_running: + return '...' + _repr_running[call_key] = 1 + try: + if not self: + return '%s[%r]()' % (self.__class__.__name__, self._sort_key) + return '%s[%r](%r)' % (self.__class__.__name__, self._sort_key, + self.items()) + finally: + del _repr_running[call_key]