ENH: adds lookup_expired_futures to asset_finder

This commit is contained in:
warren-oneill
2015-11-26 16:46:28 +01:00
parent 987d6d4e48
commit ef323a3165
6 changed files with 51 additions and 34 deletions
+14 -13
View File
@@ -1973,22 +1973,24 @@ class TestRemoveData(TestCase):
tests if futures data is removed after expiry
"""
def setUp(self):
dt = pd.Timestamp('2015-01-01', tz='UTC')
metadata = {0: {'symbol': 'X',
'expiration_date': dt + timedelta(days=5),
'end_date': dt + timedelta(days=5)},
1: {'symbol': 'Y',
'expiration_date': dt + timedelta(days=7),
'end_date': dt + timedelta(days=7)}}
dt = pd.Timestamp('2015-01-02', tz='UTC')
env = TradingEnvironment()
ix = env.trading_days.get_loc(dt)
metadata = {0: {'symbol': 'X',
'expiration_date': env.trading_days[ix + 5],
'end_date': env.trading_days[ix + 6]},
1: {'symbol': 'Y',
'expiration_date': env.trading_days[ix + 7],
'end_date': env.trading_days[ix + 8]}}
env.write_data(futures_data=metadata)
index_x = pd.date_range(dt, periods=5)
index_x = env.trading_days[ix:ix + 5]
data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100],
[5, 100]],
index=index_x, columns=['price', 'volume'])
index_y = index_x.shift(2)
index_y = env.trading_days[ix:ix + 5].shift(2)
data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100],
[10, 100]],
index=index_y, columns=['price', 'volume'])
@@ -2000,8 +2002,7 @@ class TestRemoveData(TestCase):
def test_remove_data(self):
self.algo.run(self.source)
expected_length = [1, 2, 2, 2, 2, 1]
expected_lengths = [1, 1, 2, 2, 2, 2, 1]
# initially only data for X should be sent and on the last day only
# data for Y should be sent since X is expired
for i, length in enumerate(self.algo.data):
self.assertEqual(expected_length[i], length, i)
np.testing.assert_array_equal(self.algo.data, expected_lengths)
+4 -3
View File
@@ -15,6 +15,7 @@
import unittest
import datetime
import pandas as pd
import pytz
import numpy as np
@@ -77,9 +78,9 @@ class TestEventsThroughRisk(unittest.TestCase):
algo = BuyAndHoldAlgorithm(sim_params=sim_params, env=self.env)
first_date = datetime.datetime(2006, 1, 3, tzinfo=pytz.utc)
second_date = datetime.datetime(2006, 1, 4, tzinfo=pytz.utc)
third_date = datetime.datetime(2006, 1, 5, tzinfo=pytz.utc)
first_date = pd.Timestamp('2006-01-03', tz='UTC')
second_date = pd.Timestamp('2006-01-04', tz='UTC')
third_date = pd.Timestamp('2006-01-05', tz='UTC')
trade_bar_data = [
Event({
+1 -6
View File
@@ -123,12 +123,7 @@ class TestDataFrameSource(TestCase):
self.assertEqual(5, event.sid)
event = next(source)
self.assertEqual(4, event.sid)
try:
x = False
event = next(source)
except StopIteration:
x = True
self.assertTrue(x)
self.assertRaises(StopIteration, next, source)
class TestRandomWalkSource(TestCase):
+24 -3
View File
@@ -266,9 +266,9 @@ class AssetFinder(object):
self.equities.c.share_class_symbol ==
share_class_symbol,
self.equities.c.start_date <= ad_value),
).order_by(
self.equities.c.end_date.desc(),
).execute().fetchall()
).order_by(
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_best_candidate(self, candidates):
@@ -492,6 +492,26 @@ class AssetFinder(object):
return list(map(self._retrieve_futures_contract, sids))
def lookup_expired_futures(self, start, end):
start = start.value
end = end.value
fc_cols = self.futures_contracts.c
nd = sa.func.nullif(fc_cols.notice_date, pd.tslib.iNaT)
ed = sa.func.nullif(fc_cols.expiration_date, pd.tslib.iNaT)
date = sa.func.coalesce(sa.func.min(nd, ed), ed, nd)
sids = list(map(
itemgetter('sid'),
sa.select((fc_cols.sid,)).where(
(date >= start) & (date < end)).order_by(
sa.func.coalesce(ed, nd).asc()
).execute().fetchall()
))
return sids
@property
def sids(self):
return tuple(map(
@@ -741,6 +761,7 @@ class AssetFinderCachedEquities(AssetFinder):
into memory and overrides the methods that lookup_symbol uses to look up
those equities.
"""
def __init__(self, engine):
super(AssetFinderCachedEquities, self).__init__(engine)
self.fuzzy_symbol_hashed_equities = {}
+7 -8
View File
@@ -26,7 +26,6 @@ from zipline.protocol import (
SIDData,
DATASOURCE_TYPE
)
from zipline.errors import SidNotFound
log = Logger('Trade Simulation')
@@ -65,6 +64,7 @@ class AlgorithmSimulator(object):
# We don't have a datetime for the current snapshot until we
# receive a message.
self.simulation_dt = None
self.previous_dt = self.algo_start
# =============
# Logging Setup
@@ -97,18 +97,17 @@ class AlgorithmSimulator(object):
self._call_before_trading_start(mkt_open)
for date, snapshot in stream_in:
expired_sids = self.env.asset_finder.lookup_expired_futures(
start=self.previous_dt, end=date)
self.previous_dt = date
self.simulation_dt = date
self.on_dt_changed(date)
# removing expired futures
for sid in self.current_data.keys():
for sid in expired_sids:
try:
if self.env.asset_finder.retrieve_asset(sid).end_date \
< self.simulation_dt:
del self.current_data[sid]
except (AttributeError, TypeError, ValueError,
SidNotFound):
del self.current_data[sid]
except KeyError:
continue
# If we're still in the warmup period. Use the event to
+1 -1
View File
@@ -939,7 +939,7 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
class TestRemoveDataAlgo(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.data = np.zeros(6)
self.data = np.zeros(7)
self.i = 0
def handle_data(self, data):