PERF: Look up expired futures from the in-memory Future objects

instead of issuing queries to the db.
This commit is contained in:
Richard Frank
2016-02-03 16:09:14 -05:00
parent f5b74bf7c5
commit ede1eb7aa0
5 changed files with 108 additions and 36 deletions
+8
View File
@@ -230,6 +230,7 @@ cdef class Future(Asset):
cdef readonly object auto_close_date
cdef readonly object tick_size
cdef readonly float multiplier
cdef readonly object effective_expiration
def __cinit__(self,
int sid, # sid is required
@@ -253,6 +254,13 @@ cdef class Future(Asset):
self.tick_size = tick_size
self.multiplier = multiplier
if notice_date is None:
self.effective_expiration = expiration_date
elif expiration_date is None:
self.effective_expiration = notice_date
else:
self.effective_expiration = min(notice_date, expiration_date)
def __str__(self):
if self.symbol:
return 'Future(%d [%s])' % (self.sid, self.symbol)
-24
View File
@@ -666,30 +666,6 @@ class AssetFinder(object):
contracts = self.retrieve_futures_contracts(sids)
return [contracts[sid] for sid in sids]
def lookup_expired_futures(self, start, end):
    """Return the sids of futures contracts expiring in ``[start, end)``.

    A contract's effective date is the earlier of its notice date and its
    expiration date; when one of the two is stored as ``iNaT`` the other
    one is used on its own.

    Parameters
    ----------
    start : pd.Timestamp or value accepted by ``pd.Timestamp``
        Inclusive lower bound of the window.
    end : pd.Timestamp or value accepted by ``pd.Timestamp``
        Exclusive upper bound of the window.

    Returns
    -------
    list
        The ``sid`` column of every matching row, ordered ascending by
        expiration date (falling back to the notice date).
    """
    # The date columns are compared against ``Timestamp.value``, so both
    # bounds are reduced to that integer form first.
    bounds = []
    for moment in (start, end):
        if not isinstance(moment, pd.Timestamp):
            moment = pd.Timestamp(moment)
        bounds.append(moment.value)
    lower, upper = bounds

    columns = self.futures_contracts.c
    # iNaT marks a missing date; turn it into NULL so COALESCE can skip it.
    notice = sa.func.nullif(columns.notice_date, pd.tslib.iNaT)
    expiration = sa.func.nullif(columns.expiration_date, pd.tslib.iNaT)
    # NOTE(review): the COALESCE fallback suggests MIN yields NULL when
    # either argument is NULL on the backing database -- confirm.
    effective = sa.func.coalesce(
        sa.func.min(notice, expiration), expiration, notice
    )

    query = sa.select((columns.sid,)).where(
        (effective >= lower) & (effective < upper)
    ).order_by(
        sa.func.coalesce(expiration, notice).asc()
    )
    rows = query.execute().fetchall()
    return [row['sid'] for row in rows]
@property
def sids(self):
return tuple(map(
+20 -10
View File
@@ -12,20 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from itertools import takewhile
from contextlib2 import ExitStack
from logbook import Logger, Processor
from pandas.tslib import normalize_date
from zipline.utils.api_support import ZiplineAPI
from zipline.finance.trading import NoFurtherDataError
from zipline.protocol import (
BarData,
SIDData,
DATASOURCE_TYPE
)
from zipline.utils.api_support import ZiplineAPI
from zipline.utils.data import SortedDict
log = Logger('Trade Simulation')
@@ -56,15 +57,23 @@ class AlgorithmSimulator(object):
# Snapshot Setup
# ==============
def _get_effective_expiration(sid,
                              finder=self.env.asset_finder,
                              default=self.sim_params.last_close
                              + timedelta(days=1)):
    """Return the effective expiration used to order ``sid``.

    The asset is fetched through ``finder.retrieve_asset``. Assets that
    have no ``effective_expiration`` attribute, or whose value is falsy,
    are given ``default`` instead.

    Both ``finder`` and ``default`` are bound once, when this function is
    defined; ``default`` is one day past the simulation's last close.
    """
    expiration = getattr(
        finder.retrieve_asset(sid), 'effective_expiration', None
    )
    if expiration:
        return expiration
    return default
self._get_expiration = _get_effective_expiration
# The algorithm's data as of our most recent event.
# We want an object that will have empty objects as default
# values on missing keys.
self.current_data = BarData()
self.current_data = BarData(SortedDict(self._get_expiration))
# We don't have a datetime for the current snapshot until we
# receive a message.
self.simulation_dt = None
self.previous_dt = self.algo_start
# =============
# Logging Setup
@@ -97,14 +106,15 @@ class AlgorithmSimulator(object):
self._call_before_trading_start(mkt_open)
for date, snapshot in stream_in:
expired_sids = self.env.asset_finder.lookup_expired_futures(
start=self.previous_dt, end=date)
self.previous_dt = date
self.simulation_dt = date
self.on_dt_changed(date)
# removing expired futures
for sid in expired_sids:
expired = list(takewhile(
lambda asset_id: self._get_expiration(asset_id) < date,
self.current_data
))
for sid in expired:
try:
del self.current_data[sid]
except KeyError:
+1 -1
View File
@@ -502,7 +502,7 @@ class BarData(object):
"""
def __init__(self, data=None):
    """Set up the backing container for the bar data.

    Parameters
    ----------
    data : object, optional
        Container to use as the backing store. Anything other than
        ``None`` is kept as-is, including containers that are currently
        empty. A fresh ``dict`` is created only when ``None`` is given.
    """
    # Compare against None rather than relying on truthiness: an empty
    # container supplied by the caller must not be silently swapped for a
    # plain dict, and the container's truth value is never evaluated.
    self._data = data if data is not None else {}
    self._contains_override = None
def __contains__(self, name):
+79 -1
View File
@@ -13,11 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import datetime
from collections import MutableMapping
from copy import deepcopy
try:
from six.moves._thread import get_ident
except ImportError:
from six.moves._dummy_thread import get_ident
import numpy as np
import pandas as pd
from copy import deepcopy
from toolz import merge
def _ensure_index(x):
@@ -391,3 +399,73 @@ class MutableIndexRollingPanel(object):
self.buffer.loc[non_nan_items, :, non_nan_cols])
self.buffer = new_buffer
class SortedDict(MutableMapping):
    """A mapping of key-value pairs sorted by key according to the sort_key
    function provided to the mapping. Ties from the sort_key are broken by
    comparing the original keys. `iter` traverses the keys in sort order.

    The sort key of every entry is computed once, when the entry is first
    inserted, and remembered until the entry is deleted.

    Parameters
    ----------
    sort_key : callable
        Called on keys in the mapping to produce the values by which those
        keys are sorted.
    mapping : mapping, optional
    **kwargs
        The initial mapping.

    >>> d = SortedDict(abs)
    >>> d[-1] = 'negative one'
    >>> d[0] = 'zero'
    >>> d[2] = 'two'
    >>> d  # doctest: +NORMALIZE_WHITESPACE
    SortedDict[<built-in function abs>]([(0, 'zero'), (-1, 'negative one'),
                                         (2, 'two')])
    >>> d[1] = 'one'  # Mutating the mapping maintains the sort order.
    >>> d  # doctest: +NORMALIZE_WHITESPACE
    SortedDict[<built-in function abs>]([(0, 'zero'), (-1, 'negative one'),
                                         (1, 'one'), (2, 'two')])
    """
    def __init__(self, sort_key, mapping=None, **kwargs):
        self._map = {}
        # name -> sort key computed when the name was inserted. Deletion
        # uses this instead of calling sort_key again.
        self._sort_keys = {}
        # Ascending list of (sort key, name) pairs; drives iteration order.
        self._sorted_key_names = []
        self._sort_key = sort_key
        # As with dict(), keyword arguments win over entries in `mapping`.
        self.update(mapping or {}, **kwargs)

    def __getitem__(self, name):
        return self._map[name]

    def __setitem__(self, name, value, _bisect_right=bisect.bisect_right):
        if name not in self._map:
            # Compute the sort key before touching any state, so that a
            # sort_key that raises leaves the mapping unchanged.
            key = self._sort_key(name)
            pair = (key, name)
            idx = _bisect_right(self._sorted_key_names, pair)
            self._sorted_key_names.insert(idx, pair)
            self._sort_keys[name] = key
        self._map[name] = value

    def __delitem__(self, name, _bisect_left=bisect.bisect_left):
        # Raises KeyError for unknown names before anything is modified.
        del self._map[name]
        # Look the pair up with the sort key recorded at insertion time;
        # sort_key may return something different by now, which would make
        # the bisection land on the wrong slot.
        pair = (self._sort_keys.pop(name), name)
        idx = _bisect_left(self._sorted_key_names, pair)
        del self._sorted_key_names[idx]

    def __iter__(self):
        for key, name in self._sorted_key_names:
            yield name

    def __len__(self):
        return len(self._map)

    def __repr__(self, _repr_running={}):
        # Based on OrderedDict/defaultdict. The mutable default is
        # deliberate: it is the shared registry of (object id, thread id)
        # pairs whose repr is in progress, used to cut off recursion when
        # the mapping contains itself.
        call_key = id(self), get_ident()
        if call_key in _repr_running:
            return '...'
        _repr_running[call_key] = 1
        try:
            if not self:
                return '%s[%r]()' % (self.__class__.__name__, self._sort_key)
            # list() keeps the output a plain list of pairs on Python 3,
            # where items() is a view object.
            return '%s[%r](%r)' % (self.__class__.__name__, self._sort_key,
                                   list(self.items()))
        finally:
            del _repr_running[call_key]