mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 16:35:02 +08:00
ENH: adds lookup_expired_futures to asset_finder
This commit is contained in:
+14
-13
@@ -1973,22 +1973,24 @@ class TestRemoveData(TestCase):
|
||||
tests if futures data is removed after expiry
|
||||
"""
|
||||
def setUp(self):
|
||||
dt = pd.Timestamp('2015-01-01', tz='UTC')
|
||||
metadata = {0: {'symbol': 'X',
|
||||
'expiration_date': dt + timedelta(days=5),
|
||||
'end_date': dt + timedelta(days=5)},
|
||||
1: {'symbol': 'Y',
|
||||
'expiration_date': dt + timedelta(days=7),
|
||||
'end_date': dt + timedelta(days=7)}}
|
||||
|
||||
dt = pd.Timestamp('2015-01-02', tz='UTC')
|
||||
env = TradingEnvironment()
|
||||
ix = env.trading_days.get_loc(dt)
|
||||
|
||||
metadata = {0: {'symbol': 'X',
|
||||
'expiration_date': env.trading_days[ix + 5],
|
||||
'end_date': env.trading_days[ix + 6]},
|
||||
1: {'symbol': 'Y',
|
||||
'expiration_date': env.trading_days[ix + 7],
|
||||
'end_date': env.trading_days[ix + 8]}}
|
||||
|
||||
env.write_data(futures_data=metadata)
|
||||
|
||||
index_x = pd.date_range(dt, periods=5)
|
||||
index_x = env.trading_days[ix:ix + 5]
|
||||
data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100],
|
||||
[5, 100]],
|
||||
index=index_x, columns=['price', 'volume'])
|
||||
index_y = index_x.shift(2)
|
||||
index_y = env.trading_days[ix:ix + 5].shift(2)
|
||||
data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100],
|
||||
[10, 100]],
|
||||
index=index_y, columns=['price', 'volume'])
|
||||
@@ -2000,8 +2002,7 @@ class TestRemoveData(TestCase):
|
||||
def test_remove_data(self):
|
||||
self.algo.run(self.source)
|
||||
|
||||
expected_length = [1, 2, 2, 2, 2, 1]
|
||||
expected_lengths = [1, 1, 2, 2, 2, 2, 1]
|
||||
# initially only data for X should be sent and on the last day only
|
||||
# data for Y should be sent since X is expired
|
||||
for i, length in enumerate(self.algo.data):
|
||||
self.assertEqual(expected_length[i], length, i)
|
||||
np.testing.assert_array_equal(self.algo.data, expected_lengths)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import unittest
|
||||
import datetime
|
||||
import pandas as pd
|
||||
import pytz
|
||||
|
||||
import numpy as np
|
||||
@@ -77,9 +78,9 @@ class TestEventsThroughRisk(unittest.TestCase):
|
||||
|
||||
algo = BuyAndHoldAlgorithm(sim_params=sim_params, env=self.env)
|
||||
|
||||
first_date = datetime.datetime(2006, 1, 3, tzinfo=pytz.utc)
|
||||
second_date = datetime.datetime(2006, 1, 4, tzinfo=pytz.utc)
|
||||
third_date = datetime.datetime(2006, 1, 5, tzinfo=pytz.utc)
|
||||
first_date = pd.Timestamp('2006-01-03', tz='UTC')
|
||||
second_date = pd.Timestamp('2006-01-04', tz='UTC')
|
||||
third_date = pd.Timestamp('2006-01-05', tz='UTC')
|
||||
|
||||
trade_bar_data = [
|
||||
Event({
|
||||
|
||||
@@ -123,12 +123,7 @@ class TestDataFrameSource(TestCase):
|
||||
self.assertEqual(5, event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual(4, event.sid)
|
||||
try:
|
||||
x = False
|
||||
event = next(source)
|
||||
except StopIteration:
|
||||
x = True
|
||||
self.assertTrue(x)
|
||||
self.assertRaises(StopIteration, next, source)
|
||||
|
||||
|
||||
class TestRandomWalkSource(TestCase):
|
||||
|
||||
@@ -266,9 +266,9 @@ class AssetFinder(object):
|
||||
self.equities.c.share_class_symbol ==
|
||||
share_class_symbol,
|
||||
self.equities.c.start_date <= ad_value),
|
||||
).order_by(
|
||||
self.equities.c.end_date.desc(),
|
||||
).execute().fetchall()
|
||||
).order_by(
|
||||
self.equities.c.end_date.desc(),
|
||||
).execute().fetchall()
|
||||
return candidates
|
||||
|
||||
def _get_best_candidate(self, candidates):
|
||||
@@ -492,6 +492,26 @@ class AssetFinder(object):
|
||||
|
||||
return list(map(self._retrieve_futures_contract, sids))
|
||||
|
||||
def lookup_expired_futures(self, start, end):
|
||||
start = start.value
|
||||
end = end.value
|
||||
|
||||
fc_cols = self.futures_contracts.c
|
||||
|
||||
nd = sa.func.nullif(fc_cols.notice_date, pd.tslib.iNaT)
|
||||
ed = sa.func.nullif(fc_cols.expiration_date, pd.tslib.iNaT)
|
||||
date = sa.func.coalesce(sa.func.min(nd, ed), ed, nd)
|
||||
|
||||
sids = list(map(
|
||||
itemgetter('sid'),
|
||||
sa.select((fc_cols.sid,)).where(
|
||||
(date >= start) & (date < end)).order_by(
|
||||
sa.func.coalesce(ed, nd).asc()
|
||||
).execute().fetchall()
|
||||
))
|
||||
|
||||
return sids
|
||||
|
||||
@property
|
||||
def sids(self):
|
||||
return tuple(map(
|
||||
@@ -741,6 +761,7 @@ class AssetFinderCachedEquities(AssetFinder):
|
||||
into memory and overrides the methods that lookup_symbol uses to look up
|
||||
those equities.
|
||||
"""
|
||||
|
||||
def __init__(self, engine):
|
||||
super(AssetFinderCachedEquities, self).__init__(engine)
|
||||
self.fuzzy_symbol_hashed_equities = {}
|
||||
|
||||
@@ -26,7 +26,6 @@ from zipline.protocol import (
|
||||
SIDData,
|
||||
DATASOURCE_TYPE
|
||||
)
|
||||
from zipline.errors import SidNotFound
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
|
||||
@@ -65,6 +64,7 @@ class AlgorithmSimulator(object):
|
||||
# We don't have a datetime for the current snapshot until we
|
||||
# receive a message.
|
||||
self.simulation_dt = None
|
||||
self.previous_dt = self.algo_start
|
||||
|
||||
# =============
|
||||
# Logging Setup
|
||||
@@ -97,18 +97,17 @@ class AlgorithmSimulator(object):
|
||||
self._call_before_trading_start(mkt_open)
|
||||
|
||||
for date, snapshot in stream_in:
|
||||
|
||||
expired_sids = self.env.asset_finder.lookup_expired_futures(
|
||||
start=self.previous_dt, end=date)
|
||||
self.previous_dt = date
|
||||
self.simulation_dt = date
|
||||
self.on_dt_changed(date)
|
||||
|
||||
# removing expired futures
|
||||
for sid in self.current_data.keys():
|
||||
for sid in expired_sids:
|
||||
try:
|
||||
if self.env.asset_finder.retrieve_asset(sid).end_date \
|
||||
< self.simulation_dt:
|
||||
del self.current_data[sid]
|
||||
except (AttributeError, TypeError, ValueError,
|
||||
SidNotFound):
|
||||
del self.current_data[sid]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
# If we're still in the warmup period. Use the event to
|
||||
|
||||
@@ -939,7 +939,7 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
|
||||
|
||||
class TestRemoveDataAlgo(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.data = np.zeros(6)
|
||||
self.data = np.zeros(7)
|
||||
self.i = 0
|
||||
|
||||
def handle_data(self, data):
|
||||
|
||||
Reference in New Issue
Block a user