diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index fcc67017..3140caf3 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -4127,42 +4127,62 @@ class TestPanelData(ZiplineTestCase): pd.Timestamp('2015-12-24', tz='UTC'),), ]) def test_panel_data(self, data_frequency, start_dt, end_dt): + trading_calendar = get_calendar('NYSE') if data_frequency == 'daily': history_freq = '1d' - df = create_daily_df_for_asset(get_calendar('NYSE'), - start_dt, end_dt) + create_df_for_asset = create_daily_df_for_asset + dt_transform = trading_calendar.minute_to_session_label elif data_frequency == 'minute': history_freq = '1m' - df = create_minute_df_for_asset(get_calendar('NYSE'), - start_dt, end_dt) + create_df_for_asset = create_minute_df_for_asset - panel = pd.Panel({1: df}) + def dt_transform(dt): + return dt - price_record = pd.DataFrame(columns=['current', 'previous']) + sids = range(1, 3) + dfs = {} + for sid in sids: + dfs[sid] = create_df_for_asset(trading_calendar, + start_dt, end_dt, interval=sid) + dfs[sid]['prev_close'] = dfs[sid]['close'].shift(1) + panel = pd.Panel(dfs) + + price_record = pd.Panel(items=sids, + major_axis=panel.major_axis, + minor_axis=['current', 'previous']) def initialize(algo): algo.first_bar = True + algo.equities = [] + for sid in sids: + algo.equities.append(algo.sid(sid)) def handle_data(algo, data): - price_record.loc[algo.get_datetime(), 'current'] = ( - data.current(algo.sid(1), 'price') + price_record.loc[:, dt_transform(algo.get_datetime()), + 'current'] = ( + data.current(algo.equities, 'price') ) if algo.first_bar: algo.first_bar = False else: - price_record.loc[algo.get_datetime(), 'previous'] = ( - data.history(algo.sid(1), 'price', 2, history_freq)[0] + price_record.loc[:, dt_transform(algo.get_datetime()), + 'previous'] = ( + data.history(algo.equities, 'price', + 2, history_freq).iloc[0] ) + def check_panels(): + np.testing.assert_array_equal( + price_record.values.astype('float64'), + panel.loc[:, :, ['close', + 'prev_close']].values.astype('float64') + ) + trading_algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data) trading_algo.run(data=panel) - np.testing.assert_array_equal( - np.array(price_record.transpose(), dtype='float64'), - np.array([df['close'], df['close'].shift(1)], dtype='float64') - ) - - price_record.drop(price_record.index) + check_panels() + price_record.loc[:] = np.nan run_algorithm( start=start_dt, @@ -4173,7 +4193,4 @@ class TestPanelData(ZiplineTestCase): data_frequency=data_frequency, data=panel ) - np.testing.assert_array_equal( - np.array(price_record.transpose(), dtype='float64'), - np.array([df['close'], df['close'].shift(1)], dtype='float64') - ) + check_panels()