From e59f46dfd653220a77fdf1ea1698e9137a14d2b9 Mon Sep 17 00:00:00 2001 From: lenak25 Date: Sun, 4 Mar 2018 13:44:04 +0200 Subject: [PATCH] BLD: improve periods calculation --- catalyst/exchange/utils/exchange_utils.py | 18 +--- tests/exchange/test_exchange_utils.py | 109 ++++++++++++---------- 2 files changed, 65 insertions(+), 62 deletions(-) diff --git a/catalyst/exchange/utils/exchange_utils.py b/catalyst/exchange/utils/exchange_utils.py index 9a681068..28185091 100644 --- a/catalyst/exchange/utils/exchange_utils.py +++ b/catalyst/exchange/utils/exchange_utils.py @@ -16,7 +16,6 @@ from catalyst.exchange.utils.serialization_utils import ExchangeJSONEncoder, \ ExchangeJSONDecoder from catalyst.utils.paths import data_root, ensure_directory, \ last_modified_time -from catalyst.exchange.utils.datetime_utils import get_periods_range def get_sid(symbol): @@ -740,19 +739,10 @@ def get_candles_df(candles, field, freq, bar_count, end_dt=None): for asset in candles: asset_df = transform_candles_to_df(candles[asset]) - rounded_end_dt = end_dt.round(freq) - - periods = get_periods_range( - start_dt=None, end_dt=rounded_end_dt, - freq=freq, periods=bar_count - ) - - if rounded_end_dt > end_dt: - periods = periods[:-1] - elif rounded_end_dt <= end_dt: - periods = periods[1:] - - # periods = pd.date_range(end=end_dt, periods=bar_count, freq=freq) + rounded_end_dt = end_dt.floor(freq) + periods = pd.date_range(end=rounded_end_dt, + periods=bar_count, + freq=freq) asset_df = forward_fill_df_if_needed(asset_df, periods) all_series[asset] = pd.Series(asset_df[field]) diff --git a/tests/exchange/test_exchange_utils.py b/tests/exchange/test_exchange_utils.py index ddc15cc9..2d3d1efe 100644 --- a/tests/exchange/test_exchange_utils.py +++ b/tests/exchange/test_exchange_utils.py @@ -2,6 +2,7 @@ from catalyst.exchange.utils.exchange_utils import transform_candles_to_df, \ forward_fill_df_if_needed, get_candles_df from catalyst.testing.fixtures import WithLogger, ZiplineTestCase +from datetime import timedelta from pandas import Timestamp, DataFrame, concat import numpy as np @@ -15,9 +16,59 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase): new_df.index.name = None return new_df + @classmethod + def verify_forward_fill_df_if_needed(cls, candles, periods, expected_df): + observed_df = forward_fill_df_if_needed( + transform_candles_to_df(candles), + periods) + assert (expected_df.equals(observed_df)) + + @classmethod + def verify_get_candles_df(cls, assets, candles, end_fixed_dt, + expected_df, check_next_candle=False): + # run on all the fields + for field in ['volume', 'open', 'close', 'high', 'low']: + + field_dt = cls.get_specific_field_from_df(expected_df, + field, + assets[0]) + # run on several timestamps + for delta in range(5): + end_dt = end_fixed_dt + timedelta(minutes=delta) + assert (field_dt.equals(get_candles_df({assets[0]: candles}, + field, '5T', 3, + end_dt=end_dt))) + + field_dt_a1 = cls.get_specific_field_from_df(expected_df, + field, + assets[0]) + field_dt_a2 = cls.get_specific_field_from_df(expected_df, + field, + assets[1]) + observed_df = get_candles_df({assets[0]: candles, + assets[1]: candles}, + field, '5T', 3, + end_dt=end_dt) + + assert (observed_df.equals(concat([field_dt_a1, field_dt_a2], + axis=1))) + + if check_next_candle: + # one candle forward + end_dt = end_fixed_dt + timedelta(minutes=6) + observed_df = get_candles_df({assets[0]: candles, + assets[1]: candles}, + field, '5T', 3, + end_dt=end_dt) + + assert (not observed_df.equals(concat([field_dt_a1, + field_dt_a2], + axis=1))) + assert (concat([field_dt_a1, field_dt_a2], + axis=1)[1:].equals(observed_df[:-1])) + def test_get_candles_df(self): - asset = 'btc_usdt' - asset2 = 'eth_usdt' + assets = ['btc_usdt', 'eth_usdt'] # test forward fill in the end candles = [{'high': 595, 'volume': 10, 'low': 594, @@ -51,20 +102,12 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase): Timestamp('2018-03-01 09:50:00+0000', tz='UTC'), Timestamp('2018-03-01 09:55:00+0000', tz='UTC')] - observed_df = forward_fill_df_if_needed( - transform_candles_to_df(candles), - periods) expected_df = transform_candles_to_df(expected) - assert (expected_df.equals(observed_df)) - - for field in ['volume', 'open', 'close', 'high', 'low']: - field_dt = self.get_specific_field_from_df(expected_df, - field, - asset) - assert (field_dt.equals(get_candles_df({asset: candles}, - field, '5T', 3, - end_dt=periods[2]))) + self.verify_forward_fill_df_if_needed(candles, periods, + expected_df) + self.verify_get_candles_df(assets, candles, periods[2], + expected_df, True) # test forward fill in the middle candles = [{'high': 595, 'volume': 10, 'low': 594, @@ -94,28 +137,9 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase): tz='UTC') }] - df = transform_candles_to_df(candles) - observed_df = forward_fill_df_if_needed(df, periods) expected_df = transform_candles_to_df(expected) - - assert (expected_df.equals(observed_df)) - - for field in ['volume', 'open', 'close', 'high', 'low']: - # test several assets as well - observed_df = get_candles_df({asset: candles, - asset2: candles}, - field, '5T', 3, - end_dt=periods[2]) - - field_dt_a1 = self.get_specific_field_from_df(expected_df, - field, - asset) - field_dt_a2 = self.get_specific_field_from_df(expected_df, - field, - asset2) - - assert(observed_df.equals(concat([field_dt_a1, field_dt_a2], - axis=1))) + self.verify_forward_fill_df_if_needed(candles, periods, expected_df) + self.verify_get_candles_df(assets, candles, periods[2], expected_df) # test "forward fill" at the beginning candles = [{'high': 595, 'volume': 10, 'low': 594, @@ -145,18 +169,7 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase): tz='UTC') }] - df = transform_candles_to_df(candles) - observed_df = forward_fill_df_if_needed(df, periods) expected_df = transform_candles_to_df(expected) - - assert (expected_df.equals(observed_df)) + self.verify_forward_fill_df_if_needed(candles, periods, expected_df) # Not the same due to dropna - commenting out for now - """ - for field in ['volume', 'open', 'close', 'high', 'low']: - field_dt = self.get_specific_field_from_df(observed_df, - field, - asset) - assert(field_dt.equals(get_candles_df({asset:candles}, - field, '5T', 3, - end_dt=periods[2]))) - """ + # self.verify_get_candles_df(assets, candles, periods[2], expected_df)