diff --git a/tests/pipeline/test_pipeline_algo.py b/tests/pipeline/test_pipeline_algo.py index b45f7742..4721236e 100644 --- a/tests/pipeline/test_pipeline_algo.py +++ b/tests/pipeline/test_pipeline_algo.py @@ -566,7 +566,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs, ) algo.run( - FakeDataPortal(), + FakeDataPortal(self.env), # Yes, I really do want to use the start and end dates I passed to # TradingAlgorithm. overwrite_sim_params=False, @@ -606,7 +606,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs, ) algo.run( - FakeDataPortal(), + FakeDataPortal(self.env), overwrite_sim_params=False, ) @@ -654,7 +654,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs, ) algo.run( - FakeDataPortal(), + FakeDataPortal(self.env), overwrite_sim_params=False, ) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 4c89a9b8..d9198ddc 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -88,11 +88,11 @@ from zipline.finance.asset_restrictions import ( ) from zipline.testing import ( FakeDataPortal, + copy_market_data, create_daily_df_for_asset, create_data_portal, create_data_portal_from_trade_history, create_minute_df_for_asset, - empty_trading_env, make_test_handler, make_trade_data_for_asset_info, parameter_space, @@ -100,6 +100,7 @@ from zipline.testing import ( tmp_trading_env, to_utc, trades_by_sid_to_dfs, + tmp_dir, ) from zipline.testing import RecordBatchBlotter from zipline.testing.fixtures import ( @@ -108,7 +109,6 @@ from zipline.testing.fixtures import ( WithSimParams, WithTradingEnvironment, WithTmpDir, - WithTradingCalendars, ZiplineTestCase, ) from zipline.test_algorithms import ( @@ -314,6 +314,7 @@ def handle_data(algo, data): initialize=lambda context: None, handle_data=lambda context, data: None, sim_params=self.sim_params, + env=self.env, ) # Verify that api methods get resolved dynamically by patching them out @@ -785,7 +786,8 @@ def log_nyse_close(context, data): for i, date in enumerate(dates) ] ) - with tmp_trading_env(equities=metadata) as env: + with tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: algo = TradingAlgorithm(env=env) # Set the period end to a date after the period end @@ -852,7 +854,8 @@ class TestTransformAlgorithm(WithLogger, def init_class_fixtures(cls): super(TestTransformAlgorithm, cls).init_class_fixtures() cls.futures_env = cls.enter_class_context( - tmp_trading_env(futures=cls.make_futures_info()), + tmp_trading_env(futures=cls.make_futures_info(), + load=cls.make_load_function()), ) def test_invalid_order_parameters(self): @@ -892,7 +895,8 @@ def before_trading_start(context, data): def test_run_twice(self): algo1 = TestRegisterTransformAlgorithm( sim_params=self.sim_params, - sids=[0, 1] + sids=[0, 1], + env=self.env, ) res1 = algo1.run(self.data_portal) @@ -901,7 +905,8 @@ def before_trading_start(context, data): # use the newly instantiated environment. algo2 = TestRegisterTransformAlgorithm( sim_params=self.sim_params, - sids=[0, 1] + sids=[0, 1], + env=self.env, ) res2 = algo2.run(self.data_portal) @@ -1062,7 +1067,8 @@ def before_trading_start(context, data): }] * 2) equities['symbol'] = ['A', 'B'] with TempDirectory() as tempdir, \ - tmp_trading_env(equities=equities) as env: + tmp_trading_env(equities=equities, + load=self.make_load_function()) as env: sim_params = SimulationParameters( start_session=start_session, end_session=period_end, @@ -1569,15 +1575,16 @@ class TestAlgoScript(WithLogger, def test_noop(self): algo = TradingAlgorithm(initialize=initialize_noop, - handle_data=handle_data_noop) + handle_data=handle_data_noop, + env=self.env) algo.run(self.data_portal) def test_noop_string(self): - algo = TradingAlgorithm(script=noop_algo) + algo = TradingAlgorithm(script=noop_algo, env=self.env) algo.run(self.data_portal) def test_no_handle_data(self): - algo = TradingAlgorithm(script=no_handle_data) + algo = TradingAlgorithm(script=no_handle_data, env=self.env) algo.run(self.data_portal) def test_api_calls(self): @@ -1593,7 +1600,8 @@ class TestAlgoScript(WithLogger, def test_api_get_environment(self): platform = 'zipline' algo = TradingAlgorithm(script=api_get_environment_algo, - platform=platform) + platform=platform, + env=self.env) algo.run(self.data_portal) self.assertEqual(algo.environment, platform) @@ -1779,6 +1787,7 @@ def handle_data(context, data): test_algo = TradingAlgorithm( script=record_variables, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -3169,7 +3178,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): orient='index', ) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: sim_params = factory.create_simulation_parameters( start=start, num_days=4, @@ -3296,7 +3306,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): 'sid': 999, }]) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: algo = SetAssetDateBoundsAlgorithm( sim_params=self.sim_params, env=env, @@ -3318,7 +3329,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): 'sid': 999, }]) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: data_portal = create_data_portal( env.asset_finder, tempdir, @@ -3341,7 +3353,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): 'sid': 999, }]) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: data_portal = create_data_portal( env.asset_finder, tempdir, @@ -3768,7 +3781,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase): self.assertEqual(txn['price'], expected_price) -class TestTradingAlgorithm(ZiplineTestCase): +class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase): def test_analyze_called(self): self.perf_ref = None @@ -3785,11 +3798,11 @@ class TestTradingAlgorithm(ZiplineTestCase): initialize=initialize, handle_data=handle_data, analyze=analyze, + env=self.env, ) - with empty_trading_env() as env: - data_portal = FakeDataPortal(env) - results = algo.run(data_portal) + data_portal = FakeDataPortal(self.env) + results = algo.run(data_portal) self.assertIs(results, self.perf_ref) @@ -3989,7 +4002,7 @@ class TestOrderCancelation(WithDataPortal, self.assertFalse(log_catcher.has_warnings) -class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase): +class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase): """ Tests if delisted equities are properly removed from a portfolio holding positions in said equities. @@ -4020,7 +4033,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase): sids = asset_info.index - env = self.enter_instance_context(tmp_trading_env(equities=asset_info)) + env = self.enter_instance_context( + tmp_trading_env(equities=asset_info, + load=self.make_load_function()) + ) if frequency == 'daily': dates = self.test_days @@ -4642,7 +4658,7 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase): self.assertEqual(expected_message, w.message) -class AlgoInputValidationTestCase(ZiplineTestCase): +class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase): def test_reject_passing_both_api_methods_and_script(self): script = dedent( @@ -4668,11 +4684,12 @@ class AlgoInputValidationTestCase(ZiplineTestCase): with self.assertRaises(ValueError): TradingAlgorithm( script=script, + env=self.env, **{method: lambda *args, **kwargs: None} ) -class TestPanelData(ZiplineTestCase): +class TestPanelData(WithTradingEnvironment, ZiplineTestCase): @parameterized.expand([ ('daily', @@ -4694,6 +4711,9 @@ class TestPanelData(ZiplineTestCase): def dt_transform(dt): return dt + else: + raise AssertionError('Unexpected data_frequency: %s' % + data_frequency) sids = range(1, 3) dfs = {} @@ -4734,19 +4754,26 @@ class TestPanelData(ZiplineTestCase): 'prev_close']].values.astype('float64') ) - trading_algo = TradingAlgorithm(initialize=initialize, - handle_data=handle_data) - trading_algo.run(data=panel) - check_panels() - price_record.loc[:] = np.nan + with tmp_trading_env(load=self.make_load_function()) as env: + trading_algo = TradingAlgorithm(initialize=initialize, + handle_data=handle_data, + env=env) + trading_algo.run(data=panel) + check_panels() + price_record.loc[:] = np.nan - run_algorithm( - start=start_dt, - end=end_dt, - capital_base=1, - initialize=initialize, - handle_data=handle_data, - data_frequency=data_frequency, - data=panel - ) - check_panels() + with tmp_dir() as tmpdir: + root = tmpdir.getpath('example_data/root') + copy_market_data(self.MARKET_DATA_DIR, root) + + run_algorithm( + start=start_dt, + end=end_dt, + capital_base=1, + initialize=initialize, + handle_data=handle_data, + data_frequency=data_frequency, + data=panel, + environ={'ZIPLINE_ROOT': root}, + ) + check_panels() diff --git a/tests/test_examples.py b/tests/test_examples.py index 2399064e..7e8a33ac 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,8 +21,9 @@ import pandas as pd from zipline import examples from zipline.data.bundles import register, unregister -from zipline.testing import test_resource_path -from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase +from zipline.testing import test_resource_path, copy_market_data +from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase, \ + WithTradingEnvironment from zipline.testing.predicates import assert_equal from zipline.utils.cache import dataframe_cache @@ -53,6 +54,9 @@ class ExamplesTests(WithTmpDir, ZiplineTestCase): serialization='pickle', ) + copy_market_data(WithTradingEnvironment.MARKET_DATA_DIR, + cls.tmpdir.getpath('example_data/root')) + @parameterized.expand(examples.EXAMPLE_MODULES) def test_example(self, example_name): actual_perf = examples.run_example( diff --git a/tests/test_finance.py b/tests/test_finance.py index 40667326..42c5b872 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -190,7 +190,8 @@ class FinanceTestCase(WithLogger, asset1 = self.asset_finder.retrieve_asset(1) metadata = make_simple_equity_info([asset1.sid], self.start, self.end) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: if trade_interval < timedelta(days=1): sim_params = factory.create_simulation_parameters( diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 53f52367..ebc068ee 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -57,7 +57,6 @@ from zipline.testing.fixtures import ( WithSimParams, WithTmpDir, WithTradingEnvironment, - WithTradingCalendars, ZiplineTestCase, ) from zipline.utils.calendars import get_calendar @@ -1029,7 +1028,8 @@ class TestDividendPerformanceHolidayStyle(TestDividendPerformance): END_DATE = pd.Timestamp('2003-12-08', tz='utc') -class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars, +class TestPositionPerformance(WithInstanceTmpDir, + WithTradingEnvironment, ZiplineTestCase): def create_environment_stuff(self, @@ -1054,6 +1054,7 @@ class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars, self.env = self.enter_instance_context(tmp_trading_env( equities=equities, futures=futures, + load=self.make_load_function(), )) self.sim_params = create_simulation_parameters( start=start, diff --git a/tests/test_security_list.py b/tests/test_security_list.py index b381adc3..ac7b99fe 100644 --- a/tests/test_security_list.py +++ b/tests/test_security_list.py @@ -15,7 +15,7 @@ from zipline.testing import ( ) from zipline.testing.fixtures import ( WithLogger, - WithTradingCalendars, + WithTradingEnvironment, ZiplineTestCase, ) from zipline.utils import factory @@ -82,7 +82,9 @@ class IterateRLAlgo(TradingAlgorithm): self.found = True -class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): +class SecurityListTestCase(WithLogger, + WithTradingEnvironment, + ZiplineTestCase): @classmethod def init_class_fixtures(cls): @@ -103,6 +105,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): 'symbol': symbol, 'exchange': "TEST", } for symbol in symbols]), + load=cls.make_load_function(), )) cls.sim_params = factory.create_simulation_parameters( @@ -122,6 +125,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): 'symbol': symbol, 'exchange': "TEST", } for symbol in symbols]), + load=cls.make_load_function(), )) cls.tempdir = cls.enter_class_context(tmp_dir()) @@ -304,7 +308,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): }]) with TempDirectory() as new_tempdir, \ security_list_copy(), \ - tmp_trading_env(equities=equities) as env: + tmp_trading_env(equities=equities, + load=self.make_load_function()) as env: # add a delete statement removing bzq # write a new delete statement file to disk add_security_data([], ['BZQ']) diff --git a/tests/test_tradesimulation.py b/tests/test_tradesimulation.py index 6277b8e2..a7dc8e69 100644 --- a/tests/test_tradesimulation.py +++ b/tests/test_tradesimulation.py @@ -19,7 +19,6 @@ from mock import patch from nose_parameterized import parameterized from six.moves import range -from unittest import TestCase from zipline import TradingAlgorithm from zipline.gens.sim_engine import BEFORE_TRADING_START_BAR @@ -28,8 +27,12 @@ from zipline.finance.asset_restrictions import NoRestrictions from zipline.gens.tradesimulation import AlgorithmSimulator from zipline.sources.benchmark_source import BenchmarkSource from zipline.test_algorithms import NoopAlgorithm -from zipline.testing.fixtures import WithSimParams, ZiplineTestCase, \ - WithDataPortal +from zipline.testing.fixtures import ( + WithDataPortal, + WithSimParams, + WithTradingEnvironment, + ZiplineTestCase, +) from zipline.utils import factory from zipline.testing.core import FakeDataPortal from zipline.utils.calendars.trading_calendar import days_at_time @@ -50,7 +53,7 @@ class BeforeTradingAlgorithm(TradingAlgorithm): FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute -class TestTradeSimulation(TestCase): +class TestTradeSimulation(WithTradingEnvironment, ZiplineTestCase): def fake_minutely_benchmark(self, dt): return 0.01 @@ -61,8 +64,8 @@ class TestTradeSimulation(TestCase): emission_rate='minute') with patch.object(BenchmarkSource, "get_value", self.fake_minutely_benchmark): - algo = NoopAlgorithm(sim_params=params) - algo.run(FakeDataPortal()) + algo = NoopAlgorithm(sim_params=params, env=self.env) + algo.run(FakeDataPortal(self.env)) self.assertEqual(len(algo.perf_tracker.sim_params.sessions), 1) @parameterized.expand([('%s_%s_%s' % (num_sessions, freq, emission_rate), @@ -82,8 +85,8 @@ class TestTradeSimulation(TestCase): with patch.object(BenchmarkSource, "get_value", self.fake_minutely_benchmark): - algo = BeforeTradingAlgorithm(sim_params=params) - algo.run(FakeDataPortal()) + algo = BeforeTradingAlgorithm(sim_params=params, env=self.env) + algo.run(FakeDataPortal(self.env)) self.assertEqual( len(algo.perf_tracker.sim_params.sessions), diff --git a/zipline/data/loader.py b/zipline/data/loader.py index 97b8c825..821d797f 100644 --- a/zipline/data/loader.py +++ b/zipline/data/loader.py @@ -53,13 +53,13 @@ def last_modified_time(path): return pd.Timestamp(os.path.getmtime(path), unit='s', tz='UTC') -def get_data_filepath(name): +def get_data_filepath(name, environ=None): """ Returns a handle to data file. Creates containing directory, if needed. """ - dr = data_root() + dr = data_root(environ) if not os.path.exists(dr): os.makedirs(dr) @@ -91,7 +91,8 @@ def has_data_for_dates(series_or_df, first_date, last_date): return (first <= first_date) and (last >= last_date) -def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'): +def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC', + environ=None): """ Load benchmark returns and treasury yield curves for the given calendar and benchmark symbol. @@ -162,19 +163,22 @@ def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'): # We need the trading_day to figure out the close prior to the first # date so that we can compute returns for the first date. trading_day, + environ, ) tc = ensure_treasury_data( bm_symbol, first_date, last_date, now, + environ, ) benchmark_returns = br[br.index.slice_indexer(first_date, last_date)] treasury_curves = tc[tc.index.slice_indexer(first_date, last_date)] return benchmark_returns, treasury_curves -def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): +def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day, + environ=None): """ Ensure we have benchmark data for `symbol` from `first_date` to `last_date` @@ -204,7 +208,8 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): path. """ filename = get_benchmark_filename(symbol) - data = _load_cached_data(filename, first_date, last_date, now, 'benchmark') + data = _load_cached_data(filename, first_date, last_date, now, 'benchmark', + environ) if data is not None: return data @@ -218,7 +223,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): first_date - trading_day, last_date, ) - data.to_csv(get_data_filepath(filename)) + data.to_csv(get_data_filepath(filename, environ)) except (OSError, IOError, HTTPError): logger.exception('failed to cache the new benchmark returns') raise @@ -227,7 +232,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day): return data -def ensure_treasury_data(symbol, first_date, last_date, now): +def ensure_treasury_data(symbol, first_date, last_date, now, environ=None): """ Ensure we have treasury data from treasury module associated with `symbol`. @@ -259,7 +264,8 @@ def ensure_treasury_data(symbol, first_date, last_date, now): ) first_date = max(first_date, loader_module.earliest_possible_date()) - data = _load_cached_data(filename, first_date, last_date, now, 'treasury') + data = _load_cached_data(filename, first_date, last_date, now, 'treasury', + environ) if data is not None: return data @@ -269,7 +275,7 @@ def ensure_treasury_data(symbol, first_date, last_date, now): try: data = loader_module.get_treasury_data(first_date, last_date) - data.to_csv(get_data_filepath(filename)) + data.to_csv(get_data_filepath(filename, environ)) except (OSError, IOError, HTTPError): logger.exception('failed to cache treasury data') if not has_data_for_dates(data, first_date, last_date): @@ -277,14 +283,15 @@ def ensure_treasury_data(symbol, first_date, last_date, now): return data -def _load_cached_data(filename, first_date, last_date, now, resource_name): +def _load_cached_data(filename, first_date, last_date, now, resource_name, + environ=None): if resource_name == 'benchmark': from_csv = pd.Series.from_csv else: from_csv = pd.DataFrame.from_csv # Path for the cache. - path = get_data_filepath(filename) + path = get_data_filepath(filename, environ) # If the path does not exist, it means the first download has not happened # yet, so don't try to read from 'path'. diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 78195d25..fb2aa069 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import logbook import pandas as pd @@ -86,11 +87,12 @@ class TradingEnvironment(object): trading_calendar=None, asset_db_path=':memory:', future_chain_predicates=CHAIN_PREDICATES, + environ=None, ): self.bm_symbol = bm_symbol if not load: - load = load_market_data + load = partial(load_market_data, environ=environ) if not trading_calendar: trading_calendar = get_calendar("NYSE") diff --git a/zipline/testing/__init__.py b/zipline/testing/__init__.py index 1a5449b0..3911aacf 100644 --- a/zipline/testing/__init__.py +++ b/zipline/testing/__init__.py @@ -16,6 +16,7 @@ from .core import ( # noqa check_allclose, check_arrays, chrange, + copy_market_data, create_daily_df_for_asset, create_data_portal, create_data_portal_from_trade_history, diff --git a/zipline/testing/core.py b/zipline/testing/core.py index 6afe1e4e..cbcb56e5 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -29,6 +29,7 @@ from toolz import concat, curry from zipline.assets import AssetFinder, AssetDBWriter from zipline.assets.synthetic import make_simple_equity_info from zipline.data.data_portal import DataPortal +from zipline.data.loader import get_benchmark_filename, INDEX_MAPPING from zipline.data.minute_bars import ( BcolzMinuteBarReader, BcolzMinuteBarWriter, @@ -52,6 +53,7 @@ from zipline.utils.calendars import get_calendar from zipline.utils.input_validation import expect_dimensions from zipline.utils.numpy_utils import as_column, isnat from zipline.utils.pandas_utils import timedelta_to_integral_seconds +from zipline.utils.paths import ensure_directory from zipline.utils.sentinel import sentinel import numpy as np @@ -695,11 +697,8 @@ def create_data_portal_from_trade_history(asset_finder, trading_calendar, class FakeDataPortal(DataPortal): - def __init__(self, env=None, trading_calendar=None, + def __init__(self, env, trading_calendar=None, first_trading_day=None): - if env is None: - env = TradingEnvironment() - if trading_calendar is None: trading_calendar = get_calendar("NYSE") @@ -862,6 +861,8 @@ class tmp_trading_env(tmp_asset_finder): Parameters ---------- + load : callable, optional + Function that returns benchmark returns and treasury curves. finder_cls : type, optional The type of asset finder to create from the assets db. **frames @@ -872,8 +873,13 @@ class tmp_trading_env(tmp_asset_finder): empty_trading_env tmp_asset_finder """ + def __init__(self, load=None, *args, **kwargs): + super(tmp_trading_env, self).__init__(*args, **kwargs) + self._load = load + def __enter__(self): return TradingEnvironment( + load=self._load, asset_db_path=super(tmp_trading_env, self).__enter__().engine, ) @@ -1486,6 +1492,19 @@ def patch_read_csv(url_map, module=pd, strict=False): yield +def copy_market_data(src_market_data_dir, dest_root_dir): + symbol = '^GSPC' + filenames = (get_benchmark_filename(symbol), INDEX_MAPPING[symbol][1]) + + ensure_directory(os.path.join(dest_root_dir, 'data')) + + for filename in filenames: + shutil.copyfile( + os.path.join(src_market_data_dir, filename), + os.path.join(dest_root_dir, 'data', filename) + ) + + @curry def ensure_doctest(f, name=None): """Ensure that an object gets doctested. This is useful for instances diff --git a/zipline/utils/run_algo.py b/zipline/utils/run_algo.py index 2d3b2e5b..639707fa 100644 --- a/zipline/utils/run_algo.py +++ b/zipline/utils/run_algo.py @@ -129,7 +129,7 @@ def _run(handle_data, "invalid url %r, must begin with 'sqlite:///'" % str(bundle_data.asset_finder.engine.url), ) - env = TradingEnvironment(asset_db_path=connstr) + env = TradingEnvironment(asset_db_path=connstr, environ=environ) first_trading_day =\ bundle_data.equity_minute_bar_reader.first_trading_day data = DataPortal( @@ -152,7 +152,7 @@ def _run(handle_data, "No PipelineLoader registered for column %s." % column ) else: - env = None + env = TradingEnvironment(environ=environ) choose_loader = None perf = TradingAlgorithm(