mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 20:20:55 +08:00
BUG: Fetcher wasn't working properly in before_trading_start.
In before_trading_start, fetcher lookups were using the previous day, because we looked up the previous market minute and then normalized it to a date. That is no longer the case: fetcher lookups in before_trading_start now use today's date. Also refactored how DataPortal determines whether a query should be routed to the fetcher data structures.
This commit is contained in:
@@ -134,6 +134,15 @@ Date,Value
|
||||
2006-01-01,199.3
|
||||
""".strip()
|
||||
|
||||
NFLX_DATA = """
|
||||
Settlement Date,symbol,dtc
|
||||
7/31/13,NFLX,1.690317
|
||||
8/15/13,NFLX,2.811858
|
||||
8/30/13,NFLX,2.502331
|
||||
9/13/13,NFLX,2.550829
|
||||
9/30/13,NFLX,2.64484
|
||||
"""
|
||||
|
||||
PALLADIUM_DATA = """
|
||||
Date,Hong Kong 8:30,Hong Kong 14:00,London 08:00,New York 9:30,New York 15:00
|
||||
2007-12-31,367.0,367.0,368.0,368.0,368.0
|
||||
|
||||
@@ -39,6 +39,7 @@ from .resources.fetcher_inputs.fetcher_test_data import (
|
||||
MULTI_SIGNAL_CSV_DATA,
|
||||
NON_ASSET_FETCHER_UNIVERSE_DATA,
|
||||
PALLADIUM_DATA,
|
||||
NFLX_DATA
|
||||
)
|
||||
|
||||
|
||||
@@ -84,6 +85,13 @@ class FetcherTestCase(WithResponses,
|
||||
'symbol': 'DELL',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
13: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2010-01-01', tz='UTC'),
|
||||
'symbol': 'NFLX',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
}
|
||||
},
|
||||
orient='index',
|
||||
@@ -552,3 +560,38 @@ def handle_data(context, data):
|
||||
self.assertEqual(3, results["sid_count"].iloc[0])
|
||||
self.assertEqual(3, results["sid_count"].iloc[1])
|
||||
self.assertEqual(4, results["sid_count"].iloc[2])
|
||||
|
||||
def test_fetcher_in_before_trading_start(self):
|
||||
self.responses.add(
|
||||
self.responses.GET,
|
||||
'https://fake.urls.com/fetcher_nflx_data.csv',
|
||||
body=NFLX_DATA,
|
||||
content_type='text/csv',
|
||||
)
|
||||
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=pd.Timestamp("2013-06-13", tz='UTC'),
|
||||
end=pd.Timestamp("2013-11-15", tz='UTC'),
|
||||
data_frequency="minute"
|
||||
)
|
||||
|
||||
results = self.run_algo("""
|
||||
from zipline.api import fetch_csv, record, symbol
|
||||
|
||||
def initialize(context):
|
||||
fetch_csv('https://fake.urls.com/fetcher_nflx_data.csv',
|
||||
date_column = 'Settlement Date',
|
||||
date_format = '%m/%d/%y')
|
||||
context.stock = symbol('NFLX')
|
||||
|
||||
def before_trading_start(context, data):
|
||||
record(Short_Interest = data.current(context.stock, 'dtc'))
|
||||
""", sim_params=sim_params, data_frequency="minute")
|
||||
|
||||
values = results["Short_Interest"]
|
||||
np.testing.assert_array_equal(values[0:33], np.full(33, np.nan))
|
||||
np.testing.assert_array_almost_equal(values[33:44], [1.690317] * 11)
|
||||
np.testing.assert_array_almost_equal(values[44:55], [2.811858] * 11)
|
||||
np.testing.assert_array_almost_equal(values[55:64], [2.50233] * 9)
|
||||
np.testing.assert_array_almost_equal(values[64:75], [2.550829] * 11)
|
||||
np.testing.assert_array_almost_equal(values[75:], [2.64484] * 35)
|
||||
|
||||
+32
-25
@@ -637,29 +637,17 @@ class DataPortal(object):
|
||||
elif data_frequency == 'daily':
|
||||
return self._equity_daily_reader.get_last_traded_dt(asset, dt)
|
||||
|
||||
def _check_extra_sources(self, asset, column, dt):
|
||||
@staticmethod
|
||||
def _is_extra_source(asset, field, map):
|
||||
"""
|
||||
Internal method that determines if this asset/field combination
|
||||
represents a fetcher value or a regular OHLCVP lookup.
|
||||
"""
|
||||
# If we have an extra source with a column called "price", only look
|
||||
# at it if it's on something like palladium and not AAPL (since our
|
||||
# own price data always wins when dealing with assets).
|
||||
look_in_augmented_sources = column in self._augmented_sources_map and \
|
||||
not (column in BASE_FIELDS and isinstance(asset, Asset))
|
||||
|
||||
if look_in_augmented_sources:
|
||||
day = normalize_date(dt)
|
||||
|
||||
# we're being asked for a field in an extra source
|
||||
try:
|
||||
return self._augmented_sources_map[column][asset].\
|
||||
loc[day, column]
|
||||
except:
|
||||
log.error(
|
||||
"Could not find value for asset={0}, day={1},"
|
||||
"column={2}".format(
|
||||
str(asset),
|
||||
str(day),
|
||||
str(column)))
|
||||
|
||||
raise KeyError
|
||||
return field in map and not (field in BASE_FIELDS and
|
||||
isinstance(asset, Asset))
|
||||
|
||||
def get_spot_value(self, asset, field, dt, data_frequency):
|
||||
"""
|
||||
@@ -686,10 +674,21 @@ class DataPortal(object):
|
||||
-------
|
||||
The value of the desired field at the desired time.
|
||||
"""
|
||||
extra_source_val = self._check_extra_sources(asset, field, dt)
|
||||
if self._is_extra_source(asset, field, self._augmented_sources_map):
|
||||
day = normalize_date(dt)
|
||||
|
||||
if extra_source_val is not None:
|
||||
return extra_source_val
|
||||
try:
|
||||
return \
|
||||
self._augmented_sources_map[field][asset].loc[day, field]
|
||||
except:
|
||||
log.error(
|
||||
"Could not find value for asset={0}, day={1},"
|
||||
"column={2}".format(
|
||||
str(asset),
|
||||
str(day),
|
||||
str(field)))
|
||||
|
||||
return np.NaN
|
||||
|
||||
if field not in BASE_FIELDS:
|
||||
raise KeyError("Invalid column: " + str(field))
|
||||
@@ -824,9 +823,17 @@ class DataPortal(object):
|
||||
-------
|
||||
The value of the desired field at the desired time.
|
||||
"""
|
||||
|
||||
if spot_value is None:
|
||||
spot_value = self.get_spot_value(asset, field, dt, data_frequency)
|
||||
# if this a fetcher field, we want to use perspective_dt (not dt)
|
||||
# because we want the new value as of midnight (fetcher only works
|
||||
# on a daily basis, all timestamps are on midnight)
|
||||
if self._is_extra_source(asset, field,
|
||||
self._augmented_sources_map):
|
||||
spot_value = self.get_spot_value(asset, field, perspective_dt,
|
||||
data_frequency)
|
||||
else:
|
||||
spot_value = self.get_spot_value(asset, field, dt,
|
||||
data_frequency)
|
||||
|
||||
if isinstance(asset, Equity):
|
||||
ratio = self.get_adjustments(asset, field, dt, perspective_dt)[0]
|
||||
|
||||
@@ -705,7 +705,7 @@ class FetcherDataPortal(DataPortal):
|
||||
|
||||
def get_spot_value(self, asset, field, dt, data_frequency):
|
||||
# if this is a fetcher field, exercise the regular code path
|
||||
if self._check_extra_sources(asset, field, (dt or self.current_dt)):
|
||||
if self._is_extra_source(asset, field, self._augmented_sources_map):
|
||||
return super(FetcherDataPortal, self).get_spot_value(
|
||||
asset, field, dt, data_frequency)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user