mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 14:04:47 +08:00
PERF: Look up expired futures from in-memory Futures
instead of queries to the db.
This commit is contained in:
@@ -230,6 +230,7 @@ cdef class Future(Asset):
|
||||
cdef readonly object auto_close_date
|
||||
cdef readonly object tick_size
|
||||
cdef readonly float multiplier
|
||||
cdef readonly object effective_expiration
|
||||
|
||||
def __cinit__(self,
|
||||
int sid, # sid is required
|
||||
@@ -253,6 +254,13 @@ cdef class Future(Asset):
|
||||
self.tick_size = tick_size
|
||||
self.multiplier = multiplier
|
||||
|
||||
if notice_date is None:
|
||||
self.effective_expiration = expiration_date
|
||||
elif expiration_date is None:
|
||||
self.effective_expiration = notice_date
|
||||
else:
|
||||
self.effective_expiration = min(notice_date, expiration_date)
|
||||
|
||||
def __str__(self):
|
||||
if self.symbol:
|
||||
return 'Future(%d [%s])' % (self.sid, self.symbol)
|
||||
|
||||
@@ -666,30 +666,6 @@ class AssetFinder(object):
|
||||
contracts = self.retrieve_futures_contracts(sids)
|
||||
return [contracts[sid] for sid in sids]
|
||||
|
||||
def lookup_expired_futures(self, start, end):
|
||||
if not isinstance(start, pd.Timestamp):
|
||||
start = pd.Timestamp(start)
|
||||
start = start.value
|
||||
if not isinstance(end, pd.Timestamp):
|
||||
end = pd.Timestamp(end)
|
||||
end = end.value
|
||||
|
||||
fc_cols = self.futures_contracts.c
|
||||
|
||||
nd = sa.func.nullif(fc_cols.notice_date, pd.tslib.iNaT)
|
||||
ed = sa.func.nullif(fc_cols.expiration_date, pd.tslib.iNaT)
|
||||
date = sa.func.coalesce(sa.func.min(nd, ed), ed, nd)
|
||||
|
||||
sids = list(map(
|
||||
itemgetter('sid'),
|
||||
sa.select((fc_cols.sid,)).where(
|
||||
(date >= start) & (date < end)).order_by(
|
||||
sa.func.coalesce(ed, nd).asc()
|
||||
).execute().fetchall()
|
||||
))
|
||||
|
||||
return sids
|
||||
|
||||
@property
|
||||
def sids(self):
|
||||
return tuple(map(
|
||||
|
||||
@@ -12,20 +12,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from datetime import timedelta
|
||||
from itertools import takewhile
|
||||
|
||||
from contextlib2 import ExitStack
|
||||
|
||||
from logbook import Logger, Processor
|
||||
from pandas.tslib import normalize_date
|
||||
|
||||
from zipline.utils.api_support import ZiplineAPI
|
||||
|
||||
from zipline.finance.trading import NoFurtherDataError
|
||||
from zipline.protocol import (
|
||||
BarData,
|
||||
SIDData,
|
||||
DATASOURCE_TYPE
|
||||
)
|
||||
from zipline.utils.api_support import ZiplineAPI
|
||||
from zipline.utils.data import SortedDict
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
|
||||
@@ -56,15 +57,23 @@ class AlgorithmSimulator(object):
|
||||
# Snapshot Setup
|
||||
# ==============
|
||||
|
||||
def _get_effective_expiration(sid,
|
||||
finder=self.env.asset_finder,
|
||||
default=self.sim_params.last_close
|
||||
+ timedelta(days=1)):
|
||||
asset = finder.retrieve_asset(sid)
|
||||
return getattr(asset, 'effective_expiration', None) or default
|
||||
|
||||
self._get_expiration = _get_effective_expiration
|
||||
|
||||
# The algorithm's data as of our most recent event.
|
||||
# We want an object that will have empty objects as default
|
||||
# values on missing keys.
|
||||
self.current_data = BarData()
|
||||
self.current_data = BarData(SortedDict(self._get_expiration))
|
||||
|
||||
# We don't have a datetime for the current snapshot until we
|
||||
# receive a message.
|
||||
self.simulation_dt = None
|
||||
self.previous_dt = self.algo_start
|
||||
|
||||
# =============
|
||||
# Logging Setup
|
||||
@@ -97,14 +106,15 @@ class AlgorithmSimulator(object):
|
||||
self._call_before_trading_start(mkt_open)
|
||||
|
||||
for date, snapshot in stream_in:
|
||||
expired_sids = self.env.asset_finder.lookup_expired_futures(
|
||||
start=self.previous_dt, end=date)
|
||||
self.previous_dt = date
|
||||
|
||||
self.simulation_dt = date
|
||||
self.on_dt_changed(date)
|
||||
|
||||
# removing expired futures
|
||||
for sid in expired_sids:
|
||||
expired = list(takewhile(
|
||||
lambda asset_id: self._get_expiration(asset_id) < date,
|
||||
self.current_data
|
||||
))
|
||||
for sid in expired:
|
||||
try:
|
||||
del self.current_data[sid]
|
||||
except KeyError:
|
||||
|
||||
+1
-1
@@ -502,7 +502,7 @@ class BarData(object):
|
||||
"""
|
||||
|
||||
def __init__(self, data=None):
|
||||
self._data = data or {}
|
||||
self._data = data if data is not None else {}
|
||||
self._contains_override = None
|
||||
|
||||
def __contains__(self, name):
|
||||
|
||||
+79
-1
@@ -13,11 +13,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
import datetime
|
||||
from collections import MutableMapping
|
||||
from copy import deepcopy
|
||||
|
||||
try:
|
||||
from six.moves._thread import get_ident
|
||||
except ImportError:
|
||||
from six.moves._dummy_thread import get_ident
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from copy import deepcopy
|
||||
from toolz import merge
|
||||
|
||||
|
||||
def _ensure_index(x):
|
||||
@@ -391,3 +399,73 @@ class MutableIndexRollingPanel(object):
|
||||
self.buffer.loc[non_nan_items, :, non_nan_cols])
|
||||
|
||||
self.buffer = new_buffer
|
||||
|
||||
|
||||
class SortedDict(MutableMapping):
|
||||
"""A mapping of key-value pairs sorted by key according to the sort_key
|
||||
function provided to the mapping. Ties from the sort_key are broken by
|
||||
comparing the original keys. `iter` traverses the keys in sort order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : callable
|
||||
Called on keys in the mapping to produce the values by which those keys
|
||||
are sorted.
|
||||
mapping : mapping, optional
|
||||
**kwargs
|
||||
The initial mapping.
|
||||
|
||||
>>> d = SortedDict(abs)
|
||||
>>> d[-1] = 'negative one'
|
||||
>>> d[0] = 'zero'
|
||||
>>> d[2] = 'two'
|
||||
>>> d
|
||||
SortedDict[<built-in function abs>]([(0, 'zero'), (-1, 'negative one'), (1, 'one')]) # noqa
|
||||
>>> d[1] = 'one' # Mutating the mapping maintains the sort order.
|
||||
>>> d
|
||||
SortedDict[<built-in function abs>]([(0, 'zero'), (-1, 'negative one'), (1, 'one')]) # noqa
|
||||
"""
|
||||
def __init__(self, sort_key, mapping=None, **kwargs):
|
||||
self._map = {}
|
||||
self._sorted_key_names = []
|
||||
self._sort_key = sort_key
|
||||
|
||||
self.update(merge(mapping or {}, kwargs))
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self._map[name]
|
||||
|
||||
def __setitem__(self, name, value, _bisect_right=bisect.bisect_right):
|
||||
self._map[name] = value
|
||||
if len(self._map) > len(self._sorted_key_names):
|
||||
key = self._sort_key(name)
|
||||
pair = (key, name)
|
||||
idx = _bisect_right(self._sorted_key_names, pair)
|
||||
self._sorted_key_names.insert(idx, pair)
|
||||
|
||||
def __delitem__(self, name, _bisect_left=bisect.bisect_left):
|
||||
del self._map[name]
|
||||
idx = _bisect_left(self._sorted_key_names,
|
||||
(self._sort_key(name), name))
|
||||
del self._sorted_key_names[idx]
|
||||
|
||||
def __iter__(self):
|
||||
for key, name in self._sorted_key_names:
|
||||
yield name
|
||||
|
||||
def __len__(self):
|
||||
return len(self._map)
|
||||
|
||||
def __repr__(self, _repr_running={}):
|
||||
# Based on OrderedDict/defaultdict
|
||||
call_key = id(self), get_ident()
|
||||
if call_key in _repr_running:
|
||||
return '...'
|
||||
_repr_running[call_key] = 1
|
||||
try:
|
||||
if not self:
|
||||
return '%s[%r]()' % (self.__class__.__name__, self._sort_key)
|
||||
return '%s[%r](%r)' % (self.__class__.__name__, self._sort_key,
|
||||
self.items())
|
||||
finally:
|
||||
del _repr_running[call_key]
|
||||
|
||||
Reference in New Issue
Block a user