From 9d1e15dddee335e0f4c9b6473abde8637bde355f Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Thu, 21 Apr 2016 13:03:18 -0400 Subject: [PATCH] BUG: Fetcher wasn't working properly in `before_trading_start`. We were trying to use the previous day in before_trading_start because we were looking for the previous market minute, then normalizing it. That's no longer the case, as we want to use today's date for fetcher lookups in before_trading_start. Also refactored a bit how dataportal determines if a query should be routed to the fetcher data structures. --- .../fetcher_inputs/fetcher_test_data.py | 9 +++ tests/test_fetcher.py | 43 ++++++++++++++ zipline/data/data_portal.py | 57 +++++++++++-------- zipline/testing/core.py | 2 +- 4 files changed, 85 insertions(+), 26 deletions(-) diff --git a/tests/resources/fetcher_inputs/fetcher_test_data.py b/tests/resources/fetcher_inputs/fetcher_test_data.py index fb5e9e1f..d0c775b6 100644 --- a/tests/resources/fetcher_inputs/fetcher_test_data.py +++ b/tests/resources/fetcher_inputs/fetcher_test_data.py @@ -134,6 +134,15 @@ Date,Value 2006-01-01,199.3 """.strip() +NFLX_DATA = """ +Settlement Date,symbol,dtc +7/31/13,NFLX,1.690317 +8/15/13,NFLX,2.811858 +8/30/13,NFLX,2.502331 +9/13/13,NFLX,2.550829 +9/30/13,NFLX,2.64484 +""" + PALLADIUM_DATA = """ Date,Hong Kong 8:30,Hong Kong 14:00,London 08:00,New York 9:30,New York 15:00 2007-12-31,367.0,367.0,368.0,368.0,368.0 diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py index 71f77965..4a7596cb 100644 --- a/tests/test_fetcher.py +++ b/tests/test_fetcher.py @@ -39,6 +39,7 @@ from .resources.fetcher_inputs.fetcher_test_data import ( MULTI_SIGNAL_CSV_DATA, NON_ASSET_FETCHER_UNIVERSE_DATA, PALLADIUM_DATA, + NFLX_DATA ) @@ -84,6 +85,13 @@ class FetcherTestCase(WithResponses, 'symbol': 'DELL', 'asset_type': 'equity', 'exchange': 'nasdaq' + }, + 13: { + 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), + 'end_date': pd.Timestamp('2010-01-01', tz='UTC'), + 'symbol': 'NFLX', + 'asset_type': 'equity', + 'exchange': 'nasdaq' } }, orient='index', @@ -552,3 +560,38 @@ def handle_data(context, data): self.assertEqual(3, results["sid_count"].iloc[0]) self.assertEqual(3, results["sid_count"].iloc[1]) self.assertEqual(4, results["sid_count"].iloc[2]) + + def test_fetcher_in_before_trading_start(self): + self.responses.add( + self.responses.GET, + 'https://fake.urls.com/fetcher_nflx_data.csv', + body=NFLX_DATA, + content_type='text/csv', + ) + + sim_params = factory.create_simulation_parameters( + start=pd.Timestamp("2013-06-13", tz='UTC'), + end=pd.Timestamp("2013-11-15", tz='UTC'), + data_frequency="minute" + ) + + results = self.run_algo(""" +from zipline.api import fetch_csv, record, symbol + +def initialize(context): + fetch_csv('https://fake.urls.com/fetcher_nflx_data.csv', + date_column = 'Settlement Date', + date_format = '%m/%d/%y') + context.stock = symbol('NFLX') + +def before_trading_start(context, data): + record(Short_Interest = data.current(context.stock, 'dtc')) +""", sim_params=sim_params, data_frequency="minute") + + values = results["Short_Interest"] + np.testing.assert_array_equal(values[0:33], np.full(33, np.nan)) + np.testing.assert_array_almost_equal(values[33:44], [1.690317] * 11) + np.testing.assert_array_almost_equal(values[44:55], [2.811858] * 11) + np.testing.assert_array_almost_equal(values[55:64], [2.50233] * 9) + np.testing.assert_array_almost_equal(values[64:75], [2.550829] * 11) + np.testing.assert_array_almost_equal(values[75:], [2.64484] * 35) diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index 232f11e8..d224460d 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -637,29 +637,17 @@ class DataPortal(object): elif data_frequency == 'daily': return self._equity_daily_reader.get_last_traded_dt(asset, dt) - def _check_extra_sources(self, asset, column, dt): + @staticmethod + def _is_extra_source(asset, field, map): + """ + Internal method that determines if this asset/field combination + represents a fetcher value or a regular OHLCVP lookup. + """ # If we have an extra source with a column called "price", only look # at it if it's on something like palladium and not AAPL (since our # own price data always wins when dealing with assets). - look_in_augmented_sources = column in self._augmented_sources_map and \ - not (column in BASE_FIELDS and isinstance(asset, Asset)) - - if look_in_augmented_sources: - day = normalize_date(dt) - - # we're being asked for a field in an extra source - try: - return self._augmented_sources_map[column][asset].\ - loc[day, column] - except: - log.error( - "Could not find value for asset={0}, day={1}," - "column={2}".format( - str(asset), - str(day), - str(column))) - - raise KeyError + return field in map and not (field in BASE_FIELDS and + isinstance(asset, Asset)) def get_spot_value(self, asset, field, dt, data_frequency): """ @@ -686,10 +674,21 @@ class DataPortal(object): ------- The value of the desired field at the desired time. """ - extra_source_val = self._check_extra_sources(asset, field, dt) + if self._is_extra_source(asset, field, self._augmented_sources_map): + day = normalize_date(dt) - if extra_source_val is not None: - return extra_source_val + try: + return \ + self._augmented_sources_map[field][asset].loc[day, field] + except: + log.error( + "Could not find value for asset={0}, day={1}," + "column={2}".format( + str(asset), + str(day), + str(field))) + + return np.NaN if field not in BASE_FIELDS: raise KeyError("Invalid column: " + str(field)) @@ -824,9 +823,17 @@ class DataPortal(object): ------- The value of the desired field at the desired time. """ - if spot_value is None: - spot_value = self.get_spot_value(asset, field, dt, data_frequency) + # if this a fetcher field, we want to use perspective_dt (not dt) + # because we want the new value as of midnight (fetcher only works + # on a daily basis, all timestamps are on midnight) + if self._is_extra_source(asset, field, + self._augmented_sources_map): + spot_value = self.get_spot_value(asset, field, perspective_dt, + data_frequency) + else: + spot_value = self.get_spot_value(asset, field, dt, + data_frequency) if isinstance(asset, Equity): ratio = self.get_adjustments(asset, field, dt, perspective_dt)[0] diff --git a/zipline/testing/core.py b/zipline/testing/core.py index f1109d75..a3dcae6e 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -705,7 +705,7 @@ class FetcherDataPortal(DataPortal): def get_spot_value(self, asset, field, dt, data_frequency): # if this is a fetcher field, exercise the regular code path - if self._check_extra_sources(asset, field, (dt or self.current_dt)): + if self._is_extra_source(asset, field, self._augmented_sources_map): return super(FetcherDataPortal, self).get_spot_value( asset, field, dt, data_frequency)