TST: Use testing market data with run_algorithm

so env doesn't need to download it
This commit is contained in:
Richard Frank
2017-05-17 18:58:50 -04:00
parent 3ca5a15859
commit 8734224701
7 changed files with 62 additions and 26 deletions
+17 -10
View File
@@ -88,6 +88,7 @@ from zipline.finance.asset_restrictions import (
)
from zipline.testing import (
FakeDataPortal,
copy_market_data,
create_daily_df_for_asset,
create_data_portal,
create_data_portal_from_trade_history,
@@ -99,6 +100,7 @@ from zipline.testing import (
tmp_trading_env,
to_utc,
trades_by_sid_to_dfs,
tmp_dir,
)
from zipline.testing import RecordBatchBlotter
from zipline.testing.fixtures import (
@@ -4760,13 +4762,18 @@ class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
check_panels()
price_record.loc[:] = np.nan
run_algorithm(
start=start_dt,
end=end_dt,
capital_base=1,
initialize=initialize,
handle_data=handle_data,
data_frequency=data_frequency,
data=panel
)
check_panels()
with tmp_dir() as tmpdir:
root = tmpdir.getpath('example_data/root')
copy_market_data(self.MARKET_DATA_DIR, root)
run_algorithm(
start=start_dt,
end=end_dt,
capital_base=1,
initialize=initialize,
handle_data=handle_data,
data_frequency=data_frequency,
data=panel,
environ={'ZIPLINE_ROOT': root},
)
check_panels()
+6 -2
View File
@@ -21,8 +21,9 @@ import pandas as pd
from zipline import examples
from zipline.data.bundles import register, unregister
from zipline.testing import test_resource_path
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase
from zipline.testing import test_resource_path, copy_market_data
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase, \
WithTradingEnvironment
from zipline.testing.predicates import assert_equal
from zipline.utils.cache import dataframe_cache
@@ -53,6 +54,9 @@ class ExamplesTests(WithTmpDir, ZiplineTestCase):
serialization='pickle',
)
copy_market_data(WithTradingEnvironment.MARKET_DATA_DIR,
cls.tmpdir.getpath('example_data/root'))
@parameterized.expand(examples.EXAMPLE_MODULES)
def test_example(self, example_name):
actual_perf = examples.run_example(
+18 -11
View File
@@ -53,13 +53,13 @@ def last_modified_time(path):
return pd.Timestamp(os.path.getmtime(path), unit='s', tz='UTC')
def get_data_filepath(name):
def get_data_filepath(name, environ=None):
"""
Returns a handle to data file.
Creates containing directory, if needed.
"""
dr = data_root()
dr = data_root(environ)
if not os.path.exists(dr):
os.makedirs(dr)
@@ -91,7 +91,8 @@ def has_data_for_dates(series_or_df, first_date, last_date):
return (first <= first_date) and (last >= last_date)
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC',
environ=None):
"""
Load benchmark returns and treasury yield curves for the given calendar and
benchmark symbol.
@@ -162,19 +163,22 @@ def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
# We need the trading_day to figure out the close prior to the first
# date so that we can compute returns for the first date.
trading_day,
environ,
)
tc = ensure_treasury_data(
bm_symbol,
first_date,
last_date,
now,
environ,
)
benchmark_returns = br[br.index.slice_indexer(first_date, last_date)]
treasury_curves = tc[tc.index.slice_indexer(first_date, last_date)]
return benchmark_returns, treasury_curves
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day,
environ=None):
"""
Ensure we have benchmark data for `symbol` from `first_date` to `last_date`
@@ -204,7 +208,8 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
path.
"""
filename = get_benchmark_filename(symbol)
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark',
environ)
if data is not None:
return data
@@ -218,7 +223,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
first_date - trading_day,
last_date,
)
data.to_csv(get_data_filepath(filename))
data.to_csv(get_data_filepath(filename, environ))
except (OSError, IOError, HTTPError):
logger.exception('failed to cache the new benchmark returns')
raise
@@ -227,7 +232,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
return data
def ensure_treasury_data(symbol, first_date, last_date, now):
def ensure_treasury_data(symbol, first_date, last_date, now, environ=None):
"""
Ensure we have treasury data from treasury module associated with
`symbol`.
@@ -259,7 +264,8 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
)
first_date = max(first_date, loader_module.earliest_possible_date())
data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
data = _load_cached_data(filename, first_date, last_date, now, 'treasury',
environ)
if data is not None:
return data
@@ -269,7 +275,7 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
try:
data = loader_module.get_treasury_data(first_date, last_date)
data.to_csv(get_data_filepath(filename))
data.to_csv(get_data_filepath(filename, environ))
except (OSError, IOError, HTTPError):
logger.exception('failed to cache treasury data')
if not has_data_for_dates(data, first_date, last_date):
@@ -277,14 +283,15 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
return data
def _load_cached_data(filename, first_date, last_date, now, resource_name):
def _load_cached_data(filename, first_date, last_date, now, resource_name,
environ=None):
if resource_name == 'benchmark':
from_csv = pd.Series.from_csv
else:
from_csv = pd.DataFrame.from_csv
# Path for the cache.
path = get_data_filepath(filename)
path = get_data_filepath(filename, environ)
# If the path does not exist, it means the first download has not happened
# yet, so don't try to read from 'path'.
+3 -1
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logbook
import pandas as pd
@@ -86,11 +87,12 @@ class TradingEnvironment(object):
trading_calendar=None,
asset_db_path=':memory:',
future_chain_predicates=CHAIN_PREDICATES,
environ=None,
):
self.bm_symbol = bm_symbol
if not load:
load = load_market_data
load = partial(load_market_data, environ=environ)
if not trading_calendar:
trading_calendar = get_calendar("NYSE")
+1
View File
@@ -16,6 +16,7 @@ from .core import ( # noqa
check_allclose,
check_arrays,
chrange,
copy_market_data,
create_daily_df_for_asset,
create_data_portal,
create_data_portal_from_trade_history,
+15
View File
@@ -29,6 +29,7 @@ from toolz import concat, curry
from zipline.assets import AssetFinder, AssetDBWriter
from zipline.assets.synthetic import make_simple_equity_info
from zipline.data.data_portal import DataPortal
from zipline.data.loader import get_benchmark_filename, INDEX_MAPPING
from zipline.data.minute_bars import (
BcolzMinuteBarReader,
BcolzMinuteBarWriter,
@@ -52,6 +53,7 @@ from zipline.utils.calendars import get_calendar
from zipline.utils.input_validation import expect_dimensions
from zipline.utils.numpy_utils import as_column, isnat
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
from zipline.utils.paths import ensure_directory
from zipline.utils.sentinel import sentinel
import numpy as np
@@ -1490,6 +1492,19 @@ def patch_read_csv(url_map, module=pd, strict=False):
yield
def copy_market_data(src_market_data_dir, dest_root_dir):
symbol = '^GSPC'
filenames = (get_benchmark_filename(symbol), INDEX_MAPPING[symbol][1])
ensure_directory(os.path.join(dest_root_dir, 'data'))
for filename in filenames:
shutil.copyfile(
os.path.join(src_market_data_dir, filename),
os.path.join(dest_root_dir, 'data', filename)
)
@curry
def ensure_doctest(f, name=None):
"""Ensure that an object gets doctested. This is useful for instances
+2 -2
View File
@@ -129,7 +129,7 @@ def _run(handle_data,
"invalid url %r, must begin with 'sqlite:///'" %
str(bundle_data.asset_finder.engine.url),
)
env = TradingEnvironment(asset_db_path=connstr)
env = TradingEnvironment(asset_db_path=connstr, environ=environ)
first_trading_day =\
bundle_data.equity_minute_bar_reader.first_trading_day
data = DataPortal(
@@ -152,7 +152,7 @@ def _run(handle_data,
"No PipelineLoader registered for column %s." % column
)
else:
env = None
env = TradingEnvironment(environ=environ)
choose_loader = None
perf = TradingAlgorithm(