From 87342247016b14da96122d45700587e177efb053 Mon Sep 17 00:00:00 2001 From: Richard Frank Date: Wed, 17 May 2017 18:58:50 -0400 Subject: [PATCH] TST: Use testing market data with run_algorithm so env doesn't need to download it --- tests/test_algorithm.py | 27 +++++++++++++++++---------- tests/test_examples.py | 8 ++++++-- zipline/data/loader.py | 29 ++++++++++++++++++----------- zipline/finance/trading.py | 4 +++- zipline/testing/__init__.py | 1 + zipline/testing/core.py | 15 +++++++++++++++ zipline/utils/run_algo.py | 4 ++-- 7 files changed, 62 insertions(+), 26 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 2d3a243e..d9198ddc 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -88,6 +88,7 @@ from zipline.finance.asset_restrictions import ( ) from zipline.testing import ( FakeDataPortal, + copy_market_data, create_daily_df_for_asset, create_data_portal, create_data_portal_from_trade_history, @@ -99,6 +100,7 @@ from zipline.testing import ( tmp_trading_env, to_utc, trades_by_sid_to_dfs, + tmp_dir, ) from zipline.testing import RecordBatchBlotter from zipline.testing.fixtures import ( @@ -4760,13 +4762,18 @@ class TestPanelData(WithTradingEnvironment, ZiplineTestCase): check_panels() price_record.loc[:] = np.nan - run_algorithm( - start=start_dt, - end=end_dt, - capital_base=1, - initialize=initialize, - handle_data=handle_data, - data_frequency=data_frequency, - data=panel - ) - check_panels() + with tmp_dir() as tmpdir: + root = tmpdir.getpath('example_data/root') + copy_market_data(self.MARKET_DATA_DIR, root) + + run_algorithm( + start=start_dt, + end=end_dt, + capital_base=1, + initialize=initialize, + handle_data=handle_data, + data_frequency=data_frequency, + data=panel, + environ={'ZIPLINE_ROOT': root}, + ) + check_panels() diff --git a/tests/test_examples.py b/tests/test_examples.py index 2399064e..7e8a33ac 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,8 +21,9 @@ import pandas as pd from zipline import examples from zipline.data.bundles import register, unregister -from zipline.testing import test_resource_path -from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase +from zipline.testing import test_resource_path, copy_market_data +from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase, \ + WithTradingEnvironment from zipline.testing.predicates import assert_equal from zipline.utils.cache import dataframe_cache @@ -53,6 +54,9 @@ class ExamplesTests(WithTmpDir, ZiplineTestCase): serialization='pickle', ) + copy_market_data(WithTradingEnvironment.MARKET_DATA_DIR, + cls.tmpdir.getpath('example_data/root')) + @parameterized.expand(examples.EXAMPLE_MODULES) def test_example(self, example_name): actual_perf = examples.run_example( diff --git a/zipline/data/loader.py b/zipline/data/loader.py index 97b8c825..821d797f 100644 --- a/zipline/data/loader.py +++ b/zipline/data/loader.py @@ -53,13 +53,13 @@ def last_modified_time(path): return pd.Timestamp(os.path.getmtime(path), unit='s', tz='UTC') -def get_data_filepath(name): +def get_data_filepath(name, environ=None): """ Returns a handle to data file. Creates containing directory, if needed. """ - dr = data_root() + dr = data_root(environ) if not os.path.exists(dr): os.makedirs(dr) @@ -91,7 +91,8 @@ def has_data_for_dates(series_or_df, first_date, last_date): return (first <= first_date) and (last >= last_date) -def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'): +def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC', + environ=None): """ Load benchmark returns and treasury yield curves for the given calendar and benchmark symbol. @@ -162,19 +163,22 @@ def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'): # We need the trading_day to figure out the close prior to the first # date so that we can compute returns for the first date. trading_day, + environ, ) tc = ensure_treasury_data( bm_symbol, first_date, last_date, now, + environ, ) benchmark_returns = br[br.index.slice_indexer(first_date, last_date)] treasury_curves = tc[tc.index.slice_indexer(first_date, last_date)] return benchmark_returns, treasury_curves -def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): +def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day, + environ=None): """ Ensure we have benchmark data for `symbol` from `first_date` to `last_date` @@ -204,7 +208,8 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): path. """ filename = get_benchmark_filename(symbol) - data = _load_cached_data(filename, first_date, last_date, now, 'benchmark') + data = _load_cached_data(filename, first_date, last_date, now, 'benchmark', + environ) if data is not None: return data @@ -218,7 +223,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): first_date - trading_day, last_date, ) - data.to_csv(get_data_filepath(filename)) + data.to_csv(get_data_filepath(filename, environ)) except (OSError, IOError, HTTPError): logger.exception('failed to cache the new benchmark returns') raise @@ -227,7 +232,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): return data -def ensure_treasury_data(symbol, first_date, last_date, now): +def ensure_treasury_data(symbol, first_date, last_date, now, environ=None): """ Ensure we have treasury data from treasury module associated with `symbol`. @@ -259,7 +264,8 @@ def ensure_treasury_data(symbol, first_date, last_date, now): ) first_date = max(first_date, loader_module.earliest_possible_date()) - data = _load_cached_data(filename, first_date, last_date, now, 'treasury') + data = _load_cached_data(filename, first_date, last_date, now, 'treasury', + environ) if data is not None: return data @@ -269,7 +275,7 @@ def ensure_treasury_data(symbol, first_date, last_date, now): try: data = loader_module.get_treasury_data(first_date, last_date) - data.to_csv(get_data_filepath(filename)) + data.to_csv(get_data_filepath(filename, environ)) except (OSError, IOError, HTTPError): logger.exception('failed to cache treasury data') if not has_data_for_dates(data, first_date, last_date): @@ -277,14 +283,15 @@ def ensure_treasury_data(symbol, first_date, last_date, now): return data -def _load_cached_data(filename, first_date, last_date, now, resource_name): +def _load_cached_data(filename, first_date, last_date, now, resource_name, + environ=None): if resource_name == 'benchmark': from_csv = pd.Series.from_csv else: from_csv = pd.DataFrame.from_csv # Path for the cache. - path = get_data_filepath(filename) + path = get_data_filepath(filename, environ) # If the path does not exist, it means the first download has not happened # yet, so don't try to read from 'path'. diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 78195d25..fb2aa069 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import logbook import pandas as pd @@ -86,11 +87,12 @@ class TradingEnvironment(object): trading_calendar=None, asset_db_path=':memory:', future_chain_predicates=CHAIN_PREDICATES, + environ=None, ): self.bm_symbol = bm_symbol if not load: - load = load_market_data + load = partial(load_market_data, environ=environ) if not trading_calendar: trading_calendar = get_calendar("NYSE") diff --git a/zipline/testing/__init__.py b/zipline/testing/__init__.py index 1a5449b0..3911aacf 100644 --- a/zipline/testing/__init__.py +++ b/zipline/testing/__init__.py @@ -16,6 +16,7 @@ from .core import ( # noqa check_allclose, check_arrays, chrange, + copy_market_data, create_daily_df_for_asset, create_data_portal, create_data_portal_from_trade_history, diff --git a/zipline/testing/core.py b/zipline/testing/core.py index 695ece7e..cbcb56e5 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -29,6 +29,7 @@ from toolz import concat, curry from zipline.assets import AssetFinder, AssetDBWriter from zipline.assets.synthetic import make_simple_equity_info from zipline.data.data_portal import DataPortal +from zipline.data.loader import get_benchmark_filename, INDEX_MAPPING from zipline.data.minute_bars import ( BcolzMinuteBarReader, BcolzMinuteBarWriter, @@ -52,6 +53,7 @@ from zipline.utils.calendars import get_calendar from zipline.utils.input_validation import expect_dimensions from zipline.utils.numpy_utils import as_column, isnat from zipline.utils.pandas_utils import timedelta_to_integral_seconds +from zipline.utils.paths import ensure_directory from zipline.utils.sentinel import sentinel import numpy as np @@ -1490,6 +1492,19 @@ def patch_read_csv(url_map, module=pd, strict=False): yield +def copy_market_data(src_market_data_dir, dest_root_dir): + symbol = '^GSPC' + filenames = (get_benchmark_filename(symbol), INDEX_MAPPING[symbol][1]) + + ensure_directory(os.path.join(dest_root_dir, 'data')) + + for filename in filenames: + shutil.copyfile( + os.path.join(src_market_data_dir, filename), + os.path.join(dest_root_dir, 'data', filename) + ) + + @curry def ensure_doctest(f, name=None): """Ensure that an object gets doctested. This is useful for instances diff --git a/zipline/utils/run_algo.py b/zipline/utils/run_algo.py index 2d3b2e5b..639707fa 100644 --- a/zipline/utils/run_algo.py +++ b/zipline/utils/run_algo.py @@ -129,7 +129,7 @@ def _run(handle_data, "invalid url %r, must begin with 'sqlite:///'" % str(bundle_data.asset_finder.engine.url), ) - env = TradingEnvironment(asset_db_path=connstr) + env = TradingEnvironment(asset_db_path=connstr, environ=environ) first_trading_day =\ bundle_data.equity_minute_bar_reader.first_trading_day data = DataPortal( @@ -152,7 +152,7 @@ def _run(handle_data, "No PipelineLoader registered for column %s." % column ) else: - env = None + env = TradingEnvironment(environ=environ) choose_loader = None perf = TradingAlgorithm(