mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 18:01:47 +08:00
TST: Expand Panel data test to test for multiple sids.
This commit is contained in:
+37
-20
@@ -4127,42 +4127,62 @@ class TestPanelData(ZiplineTestCase):
|
||||
pd.Timestamp('2015-12-24', tz='UTC'),),
|
||||
])
|
||||
def test_panel_data(self, data_frequency, start_dt, end_dt):
|
||||
trading_calendar = get_calendar('NYSE')
|
||||
if data_frequency == 'daily':
|
||||
history_freq = '1d'
|
||||
df = create_daily_df_for_asset(get_calendar('NYSE'),
|
||||
start_dt, end_dt)
|
||||
create_df_for_asset = create_daily_df_for_asset
|
||||
dt_transform = trading_calendar.minute_to_session_label
|
||||
elif data_frequency == 'minute':
|
||||
history_freq = '1m'
|
||||
df = create_minute_df_for_asset(get_calendar('NYSE'),
|
||||
start_dt, end_dt)
|
||||
create_df_for_asset = create_minute_df_for_asset
|
||||
|
||||
panel = pd.Panel({1: df})
|
||||
def dt_transform(dt):
|
||||
return dt
|
||||
|
||||
price_record = pd.DataFrame(columns=['current', 'previous'])
|
||||
sids = range(1, 3)
|
||||
dfs = {}
|
||||
for sid in sids:
|
||||
dfs[sid] = create_df_for_asset(trading_calendar,
|
||||
start_dt, end_dt, interval=sid)
|
||||
dfs[sid]['prev_close'] = dfs[sid]['close'].shift(1)
|
||||
panel = pd.Panel(dfs)
|
||||
|
||||
price_record = pd.Panel(items=sids,
|
||||
major_axis=panel.major_axis,
|
||||
minor_axis=['current', 'previous'])
|
||||
|
||||
def initialize(algo):
|
||||
algo.first_bar = True
|
||||
algo.equities = []
|
||||
for sid in sids:
|
||||
algo.equities.append(algo.sid(sid))
|
||||
|
||||
def handle_data(algo, data):
|
||||
price_record.loc[algo.get_datetime(), 'current'] = (
|
||||
data.current(algo.sid(1), 'price')
|
||||
price_record.loc[:, dt_transform(algo.get_datetime()),
|
||||
'current'] = (
|
||||
data.current(algo.equities, 'price')
|
||||
)
|
||||
if algo.first_bar:
|
||||
algo.first_bar = False
|
||||
else:
|
||||
price_record.loc[algo.get_datetime(), 'previous'] = (
|
||||
data.history(algo.sid(1), 'price', 2, history_freq)[0]
|
||||
price_record.loc[:, dt_transform(algo.get_datetime()),
|
||||
'previous'] = (
|
||||
data.history(algo.equities, 'price',
|
||||
2, history_freq).iloc[0]
|
||||
)
|
||||
|
||||
def check_panels():
|
||||
np.testing.assert_array_equal(
|
||||
price_record.values.astype('float64'),
|
||||
panel.loc[:, :, ['close',
|
||||
'prev_close']].values.astype('float64')
|
||||
)
|
||||
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data)
|
||||
trading_algo.run(data=panel)
|
||||
np.testing.assert_array_equal(
|
||||
np.array(price_record.transpose(), dtype='float64'),
|
||||
np.array([df['close'], df['close'].shift(1)], dtype='float64')
|
||||
)
|
||||
|
||||
price_record.drop(price_record.index)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
@@ -4173,7 +4193,4 @@ class TestPanelData(ZiplineTestCase):
|
||||
data_frequency=data_frequency,
|
||||
data=panel
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
np.array(price_record.transpose(), dtype='float64'),
|
||||
np.array([df['close'], df['close'].shift(1)], dtype='float64')
|
||||
)
|
||||
check_panels()
|
||||
|
||||
Reference in New Issue
Block a user