mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 17:29:56 +08:00
BLD: improve periods calculation
This commit is contained in:
@@ -16,7 +16,6 @@ from catalyst.exchange.utils.serialization_utils import ExchangeJSONEncoder, \
|
||||
ExchangeJSONDecoder
|
||||
from catalyst.utils.paths import data_root, ensure_directory, \
|
||||
last_modified_time
|
||||
from catalyst.exchange.utils.datetime_utils import get_periods_range
|
||||
|
||||
|
||||
def get_sid(symbol):
|
||||
@@ -740,19 +739,10 @@ def get_candles_df(candles, field, freq, bar_count, end_dt=None):
|
||||
|
||||
for asset in candles:
|
||||
asset_df = transform_candles_to_df(candles[asset])
|
||||
rounded_end_dt = end_dt.round(freq)
|
||||
|
||||
periods = get_periods_range(
|
||||
start_dt=None, end_dt=rounded_end_dt,
|
||||
freq=freq, periods=bar_count
|
||||
)
|
||||
|
||||
if rounded_end_dt > end_dt:
|
||||
periods = periods[:-1]
|
||||
elif rounded_end_dt <= end_dt:
|
||||
periods = periods[1:]
|
||||
|
||||
# periods = pd.date_range(end=end_dt, periods=bar_count, freq=freq)
|
||||
rounded_end_dt = end_dt.floor(freq)
|
||||
periods = pd.date_range(end=rounded_end_dt,
|
||||
periods=bar_count,
|
||||
freq=freq)
|
||||
asset_df = forward_fill_df_if_needed(asset_df, periods)
|
||||
|
||||
all_series[asset] = pd.Series(asset_df[field])
|
||||
|
||||
@@ -2,6 +2,7 @@ from catalyst.exchange.utils.exchange_utils import transform_candles_to_df, \
|
||||
forward_fill_df_if_needed, get_candles_df
|
||||
|
||||
from catalyst.testing.fixtures import WithLogger, ZiplineTestCase
|
||||
from datetime import timedelta
|
||||
from pandas import Timestamp, DataFrame, concat
|
||||
|
||||
import numpy as np
|
||||
@@ -15,9 +16,59 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase):
|
||||
new_df.index.name = None
|
||||
return new_df
|
||||
|
||||
@classmethod
|
||||
def verify_forward_fill_df_if_needed(cls, candles, periods, expected_df):
|
||||
observed_df = forward_fill_df_if_needed(
|
||||
transform_candles_to_df(candles),
|
||||
periods)
|
||||
assert (expected_df.equals(observed_df))
|
||||
|
||||
@classmethod
|
||||
def verify_get_candles_df(cls, assets, candles, end_fixed_dt,
|
||||
expected_df, check_next_candle=False):
|
||||
# run on all the fields
|
||||
for field in ['volume', 'open', 'close', 'high', 'low']:
|
||||
|
||||
field_dt = cls.get_specific_field_from_df(expected_df,
|
||||
field,
|
||||
assets[0])
|
||||
# run on several timestamps
|
||||
for delta in range(5):
|
||||
end_dt = end_fixed_dt + timedelta(minutes=delta)
|
||||
assert (field_dt.equals(get_candles_df({assets[0]: candles},
|
||||
field, '5T', 3,
|
||||
end_dt=end_dt)))
|
||||
|
||||
field_dt_a1 = cls.get_specific_field_from_df(expected_df,
|
||||
field,
|
||||
assets[0])
|
||||
field_dt_a2 = cls.get_specific_field_from_df(expected_df,
|
||||
field,
|
||||
assets[1])
|
||||
observed_df = get_candles_df({assets[0]: candles,
|
||||
assets[1]: candles},
|
||||
field, '5T', 3,
|
||||
end_dt=end_dt)
|
||||
|
||||
assert (observed_df.equals(concat([field_dt_a1, field_dt_a2],
|
||||
axis=1)))
|
||||
|
||||
if check_next_candle:
|
||||
# one candle forward
|
||||
end_dt = end_fixed_dt + timedelta(minutes=6)
|
||||
observed_df = get_candles_df({assets[0]: candles,
|
||||
assets[1]: candles},
|
||||
field, '5T', 3,
|
||||
end_dt=end_dt)
|
||||
|
||||
assert (not observed_df.equals(concat([field_dt_a1,
|
||||
field_dt_a2],
|
||||
axis=1)))
|
||||
assert (concat([field_dt_a1, field_dt_a2],
|
||||
axis=1)[1:].equals(observed_df[:-1]))
|
||||
|
||||
def test_get_candles_df(self):
|
||||
asset = 'btc_usdt'
|
||||
asset2 = 'eth_usdt'
|
||||
assets = ['btc_usdt', 'eth_usdt']
|
||||
|
||||
# test forward fill in the end
|
||||
candles = [{'high': 595, 'volume': 10, 'low': 594,
|
||||
@@ -51,20 +102,12 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase):
|
||||
Timestamp('2018-03-01 09:50:00+0000', tz='UTC'),
|
||||
Timestamp('2018-03-01 09:55:00+0000', tz='UTC')]
|
||||
|
||||
observed_df = forward_fill_df_if_needed(
|
||||
transform_candles_to_df(candles),
|
||||
periods)
|
||||
expected_df = transform_candles_to_df(expected)
|
||||
|
||||
assert (expected_df.equals(observed_df))
|
||||
|
||||
for field in ['volume', 'open', 'close', 'high', 'low']:
|
||||
field_dt = self.get_specific_field_from_df(expected_df,
|
||||
field,
|
||||
asset)
|
||||
assert (field_dt.equals(get_candles_df({asset: candles},
|
||||
field, '5T', 3,
|
||||
end_dt=periods[2])))
|
||||
self.verify_forward_fill_df_if_needed(candles, periods,
|
||||
expected_df)
|
||||
self.verify_get_candles_df(assets, candles, periods[2],
|
||||
expected_df, True)
|
||||
|
||||
# test forward fill in the middle
|
||||
candles = [{'high': 595, 'volume': 10, 'low': 594,
|
||||
@@ -94,28 +137,9 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase):
|
||||
tz='UTC')
|
||||
}]
|
||||
|
||||
df = transform_candles_to_df(candles)
|
||||
observed_df = forward_fill_df_if_needed(df, periods)
|
||||
expected_df = transform_candles_to_df(expected)
|
||||
|
||||
assert (expected_df.equals(observed_df))
|
||||
|
||||
for field in ['volume', 'open', 'close', 'high', 'low']:
|
||||
# test several assets as well
|
||||
observed_df = get_candles_df({asset: candles,
|
||||
asset2: candles},
|
||||
field, '5T', 3,
|
||||
end_dt=periods[2])
|
||||
|
||||
field_dt_a1 = self.get_specific_field_from_df(expected_df,
|
||||
field,
|
||||
asset)
|
||||
field_dt_a2 = self.get_specific_field_from_df(expected_df,
|
||||
field,
|
||||
asset2)
|
||||
|
||||
assert(observed_df.equals(concat([field_dt_a1, field_dt_a2],
|
||||
axis=1)))
|
||||
self.verify_forward_fill_df_if_needed(candles, periods, expected_df)
|
||||
self.verify_get_candles_df(assets, candles, periods[2], expected_df)
|
||||
|
||||
# test "forward fill" at the beginning
|
||||
candles = [{'high': 595, 'volume': 10, 'low': 594,
|
||||
@@ -145,18 +169,7 @@ class TestExchangeUtils(WithLogger, ZiplineTestCase):
|
||||
tz='UTC')
|
||||
}]
|
||||
|
||||
df = transform_candles_to_df(candles)
|
||||
observed_df = forward_fill_df_if_needed(df, periods)
|
||||
expected_df = transform_candles_to_df(expected)
|
||||
|
||||
assert (expected_df.equals(observed_df))
|
||||
self.verify_forward_fill_df_if_needed(candles, periods, expected_df)
|
||||
# Not the same due to dropna - commenting out for now
|
||||
"""
|
||||
for field in ['volume', 'open', 'close', 'high', 'low']:
|
||||
field_dt = self.get_specific_field_from_df(observed_df,
|
||||
field,
|
||||
asset)
|
||||
assert(field_dt.equals(get_candles_df({asset:candles},
|
||||
field, '5T', 3,
|
||||
end_dt=periods[2])))
|
||||
"""
|
||||
# self.verify_get_candles_df(assets, candles, periods[2], expected_df)
|
||||
|
||||
Reference in New Issue
Block a user