diff --git a/catalyst/exchange/exchange_data_portal.py b/catalyst/exchange/exchange_data_portal.py index 8f9665dc..c6523326 100644 --- a/catalyst/exchange/exchange_data_portal.py +++ b/catalyst/exchange/exchange_data_portal.py @@ -9,8 +9,9 @@ from catalyst.exchange.exchange_bundle import ExchangeBundle from catalyst.exchange.exchange_errors import ( ExchangeRequestError, PricingDataNotLoadedError) -from catalyst.exchange.utils.exchange_utils import resample_history_df, group_assets_by_exchange -from catalyst.exchange.utils.datetime_utils import get_frequency +from catalyst.exchange.utils.exchange_utils import resample_history_df, \ + group_assets_by_exchange +from catalyst.exchange.utils.datetime_utils import get_frequency, get_start_dt from logbook import Logger from redo import retry @@ -311,7 +312,8 @@ class DataPortalExchangeBacktest(DataPortalExchangeBase): algo_end_dt=self._last_available_session, ) - df = resample_history_df(pd.DataFrame(series), freq, field) + start_dt = get_start_dt(end_dt, adj_bar_count, data_frequency) + df = resample_history_df(pd.DataFrame(series), freq, field, start_dt) return df def get_exchange_spot_value(self, diff --git a/catalyst/exchange/utils/exchange_utils.py b/catalyst/exchange/utils/exchange_utils.py index 132845bb..ac48728e 100644 --- a/catalyst/exchange/utils/exchange_utils.py +++ b/catalyst/exchange/utils/exchange_utils.py @@ -126,8 +126,8 @@ def get_exchange_symbols(exchange_name, is_local=False, environ=None): filename = get_exchange_symbols_filename(exchange_name, is_local) if not is_local and (not os.path.isfile(filename) or pd.Timedelta( - pd.Timestamp('now', tz='UTC') - last_modified_time( - filename)).days > 1): + pd.Timestamp('now', tz='UTC') - last_modified_time( + filename)).days > 1): try: download_exchange_symbols(exchange_name, environ) except Exception as e: @@ -512,7 +512,7 @@ def get_common_assets(exchanges): return assets -def resample_history_df(df, freq, field): +def resample_history_df(df, freq, field, start_dt=None): """ Resample the OHCLV DataFrame using the specified frequency. @@ -541,6 +541,7 @@ def resample_history_df(df, freq, field): raise ValueError('Invalid field.') resampled_df = df.resample(freq, closed='left', label='left').agg(agg) + resampled_df = resampled_df[resampled_df.index >= start_dt] return resampled_df @@ -567,7 +568,7 @@ def mixin_market_params(exchange_name, params, market): params['taker'] = 0.002 elif 'maker' in market and 'taker' in market \ - and market['maker'] is not None and market['taker'] is not None: + and market['maker'] is not None and market['taker'] is not None: params['maker'] = market['maker'] params['taker'] = market['taker'] diff --git a/tests/exchange/test_suites/test_suite_bundle.py b/tests/exchange/test_suites/test_suite_bundle.py index 15d9cbcd..0bf8b1d2 100644 --- a/tests/exchange/test_suites/test_suite_bundle.py +++ b/tests/exchange/test_suites/test_suite_bundle.py @@ -37,7 +37,7 @@ class TestSuiteBundle: return data_portal def compare_bundle_with_exchange(self, exchange, assets, end_dt, bar_count, - freq, data_frequency, data_portal): + freq, data_frequency, data_portal, field): """ Creates DataFrames from the bundle and exchange for the specified data set. @@ -62,8 +62,8 @@ class TestSuiteBundle: with log_catcher: symbols = [asset.symbol for asset in assets] print( - 'comparing data for {}/{} with {} timeframe until {}'.format( - exchange.name, symbols, freq, end_dt + 'comparing {} for {}/{} with {} timeframe until {}'.format( + field, exchange.name, symbols, freq, end_dt ) ) data['bundle'] = data_portal.get_history_window( @@ -71,13 +71,13 @@ class TestSuiteBundle: end_dt=end_dt, bar_count=bar_count, frequency=freq, - field='close', + field=field, data_frequency=data_frequency, ) set_print_settings() print( - 'the bundle first / last row:\n{}'.format( - data['bundle'].iloc[[-1, 0]] + 'the bundle data:\n{}'.format( + data['bundle'] ) ) candles = exchange.get_candles( @@ -88,14 +88,14 @@ class TestSuiteBundle: ) data['exchange'] = get_candles_df( candles=candles, - field='close', + field=field, freq=freq, bar_count=bar_count, end_dt=end_dt, ) print( - 'the exchange first / last row:\n{}'.format( - data['exchange'].iloc[[-1, 0]] + 'the exchange data:\n{}'.format( + data['exchange'] ) ) for source in data: @@ -118,8 +118,10 @@ class TestSuiteBundle: check_less_precise=min([a.decimals for a in assets]), ) except Exception as e: - print('Some differences were found within a 1 decimal point ' - 'interval of confidence: {}'.format(e)) + print( + 'Some differences were found within a 1 decimal point ' + 'interval of confidence: {}'.format(e) + ) with open(os.path.join(folder, 'compare.txt'), 'w+') as handle: handle.write(e.args[0]) @@ -203,8 +205,11 @@ class TestSuiteBundle: frequencies = exchange.get_candle_frequencies(data_frequency) freq = random.sample(frequencies, 1)[0] + rnd = random.SystemRandom() + # field = rnd.choice(['open', 'high', 'low', 'close', 'volume']) + field = rnd.choice(['close']) - bar_count = random.randint(1, 10) + bar_count = random.randint(3, 6) assets = select_random_assets( exchange.assets, asset_population @@ -229,6 +234,7 @@ class TestSuiteBundle: freq=freq, data_frequency=data_frequency, data_portal=data_portal, + field=field, ) pass