From dc01c45dc4d9dc56bc23dfe3b08dbc78ad0e833b Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Tue, 5 Apr 2016 11:08:04 -0400 Subject: [PATCH] DEV: Apply adjustments for portfolio and account in BTS completely copied from https://github.com/quantopian/zipline/pull/1104/ All credit goes to Andrew Liang (@lianga888) --- tests/test_algorithm.py | 168 +++++++++++++++++- tests/test_perf_tracking.py | 37 ++-- zipline/algorithm.py | 10 +- .../finance/performance/position_tracker.py | 27 ++- zipline/finance/performance/tracker.py | 14 +- zipline/gens/tradesimulation.py | 6 +- 6 files changed, 217 insertions(+), 45 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 9a81340b..ad834933 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -35,7 +35,11 @@ from zipline.api import FixedSlippage from zipline.data.data_portal import DataPortal from zipline.data.minute_bars import BcolzMinuteBarWriter, \ US_EQUITIES_MINUTES_PER_DAY, BcolzMinuteBarReader -from zipline.data.us_equity_pricing import BcolzDailyBarReader +from zipline.data.us_equity_pricing import ( + BcolzDailyBarReader, + SQLiteAdjustmentWriter, + SQLiteAdjustmentReader +) from zipline.finance.commission import PerShare from zipline.finance.execution import LimitOrder from zipline.finance.order import ORDER_STATUS @@ -47,7 +51,9 @@ from zipline.testing.core import ( create_data_portal, create_data_portal_from_trade_history, DailyBarWriterFromDataFrames, - create_daily_df_for_asset, write_minute_data_for_asset, + create_daily_df_for_asset, + write_minute_data_for_asset, + MockDailyBarReader, make_test_handler) from zipline.errors import ( OrderDuringInitialize, @@ -108,6 +114,7 @@ from zipline.testing import ( setup_logger, teardown_logger, parameter_space, + str_to_seconds ) from zipline.utils.api_support import ZiplineAPI, set_algo_instance from zipline.utils.context_tricks import CallbackManager @@ -1018,7 +1025,7 @@ class TestBeforeTradingStart(TestCase): ) equities_data = {} - for sid in [1, 2]: + for sid in [1, 2, 3]: equities_data[sid] = { "start_date": cls.trading_days[0], "end_date": cls.trading_days[-1], @@ -1029,6 +1036,7 @@ class TestBeforeTradingStart(TestCase): cls.asset1 = cls.env.asset_finder.retrieve_asset(1) cls.asset2 = cls.env.asset_finder.retrieve_asset(2) + cls.SPLIT_ASSET = cls.env.asset_finder.retrieve_asset(3) market_opens = cls.env.open_and_closes.market_open.loc[ cls.trading_days] @@ -1049,6 +1057,24 @@ class TestBeforeTradingStart(TestCase): cls.trading_days[-1], sid ) + # Write data with split asset + asset_minutes = cls.env.minutes_for_days_in_range( + cls.trading_days[0], cls.trading_days[-1]) + minutes_count = len(asset_minutes) + minutes_arr = np.array(range(1, 1 + minutes_count)) + + df = pd.DataFrame({ + "open": minutes_arr + 1, + "high": minutes_arr + 2, + "low": minutes_arr - 1, + "close": minutes_arr, + "volume": 100 * minutes_arr, + "dt": asset_minutes + }).set_index("dt") + df.iloc[780:] = df.iloc[780:] / 2.0 + + minute_writer.write(3, df) + # asset2 only trades every 50 minutes write_minute_data_for_asset( cls.env, minute_writer, cls.trading_days[0], @@ -1056,12 +1082,15 @@ class TestBeforeTradingStart(TestCase): ) cls.minute_reader = BcolzMinuteBarReader(cls.tempdir.path) + cls.adj_reader = cls.create_adjustments_reader() cls.daily_path = cls.tempdir.getpath("testdaily.bcolz") dfs = { 1: create_daily_df_for_asset(cls.env, cls.trading_days[0], cls.trading_days[-1]), 2: create_daily_df_for_asset(cls.env, cls.trading_days[0], + cls.trading_days[-1]), + 3: create_daily_df_for_asset(cls.env, cls.trading_days[0], cls.trading_days[-1]) } daily_writer = DailyBarWriterFromDataFrames(dfs) @@ -1077,9 +1106,45 @@ class TestBeforeTradingStart(TestCase): cls.data_portal = DataPortal( env=cls.env, equity_daily_reader=BcolzDailyBarReader(cls.daily_path), - equity_minute_reader=cls.minute_reader + equity_minute_reader=cls.minute_reader, + adjustment_reader=cls.adj_reader ) + @classmethod + def create_adjustments_reader(cls): + path = cls.tempdir.getpath("test_adjustments.db") + + adj_writer = SQLiteAdjustmentWriter( + path, + cls.env.trading_days, + MockDailyBarReader() + ) + + splits = pd.DataFrame([ + { + 'effective_date': str_to_seconds("2016-01-07"), + 'ratio': 0.5, + 'sid': cls.SPLIT_ASSET.sid + } + ]) + + # Mergers and Dividends are not tested, but we need to have these + # anyway + mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid']) + mergers.effective_date = mergers.effective_date.astype(int) + mergers.ratio = mergers.ratio.astype(float) + mergers.sid = mergers.sid.astype(int) + + dividends = pd.DataFrame({}, columns=['ex_date', 'record_date', + 'declared_date', 'pay_date', + 'amount', 'sid']) + dividends.amount = dividends.amount.astype(float) + dividends.sid = dividends.sid.astype(int) + + adj_writer.write(splits, mergers, dividends) + + return SQLiteAdjustmentReader(path) + @classmethod def tearDownClass(cls): cls.tempdir.cleanup() @@ -1209,14 +1274,21 @@ class TestBeforeTradingStart(TestCase): def initialize(context): context.ordered = False + context.hd_portfolio = context.portfolio def before_trading_start(context, data): - record(pos_value=context.portfolio.positions_value) + bts_portfolio = context.portfolio + + # Assert that the portfolio in BTS is the same as the last + # portfolio in handle_data + assert (context.hd_portfolio == bts_portfolio) + record(pos_value=bts_portfolio.positions_value) def handle_data(context, data): if not context.ordered: order(sid(1), 1) context.ordered = True + context.hd_portfolio = context.portfolio """) algo = TradingAlgorithm( @@ -1241,14 +1313,21 @@ class TestBeforeTradingStart(TestCase): def initialize(context): context.ordered = False + context.hd_account = context.account def before_trading_start(context, data): + bts_account = context.account + + # Assert that the account in BTS is the same as the last account + # in handle_data + assert (context.hd_account == bts_account) record(port_value=context.account.equity_with_loan) def handle_data(context, data): if not context.ordered: order(sid(1), 1) context.ordered = True + context.hd_acount = context.account """) algo = TradingAlgorithm( @@ -1268,6 +1347,85 @@ class TestBeforeTradingStart(TestCase): self.assertAlmostEqual(results.port_value.iloc[1], 10000 + 780 - 392 - 1) + def test_portfolio_bts_with_overnight_split(self): + algo_code = dedent(""" + from zipline.api import order, sid, record + def initialize(context): + context.ordered = False + context.hd_portfolio = context.portfolio + def before_trading_start(context, data): + bts_portfolio = context.portfolio + # Assert that the portfolio in BTS is the same as the last + # portfolio in handle_data, except for the positions + for k in bts_portfolio.__dict__: + if k != 'positions': + assert (context.hd_portfolio.__dict__[k] + == bts_portfolio.__dict__[k]) + record(pos_value=bts_portfolio.positions_value) + record(pos_amount=bts_portfolio.positions[sid(3)]['amount']) + record(last_sale_price=bts_portfolio.positions[sid(3)] + ['last_sale_price']) + def handle_data(context, data): + if not context.ordered: + order(sid(3), 1) + context.ordered = True + context.hd_portfolio = context.portfolio + """) + + algo = TradingAlgorithm( + script=algo_code, + data_frequency="minute", + sim_params=self.sim_params, + env=self.env + ) + + results = algo.run(self.data_portal) + + # On 1/07, positions value should by 780, same as without split + self.assertEqual(results.pos_value.iloc[0], 0) + self.assertEqual(results.pos_value.iloc[1], 780) + + # On 1/07, after applying the split, 1 share becomes 2 + self.assertEqual(results.pos_amount.iloc[0], 0) + self.assertEqual(results.pos_amount.iloc[1], 2) + + # On 1/07, after applying the split, last sale price is halved + self.assertEqual(results.last_sale_price.iloc[0], 0) + self.assertEqual(results.last_sale_price.iloc[1], 390) + + def test_account_bts_with_overnight_split(self): + algo_code = dedent(""" + from zipline.api import order, sid, record + def initialize(context): + context.ordered = False + context.hd_account = context.account + def before_trading_start(context, data): + bts_account = context.account + # Assert that the account in BTS is the same as the last account + # in handle_data + assert (context.hd_account == bts_account) + record(port_value=bts_account.equity_with_loan) + def handle_data(context, data): + if not context.ordered: + order(sid(1), 1) + context.ordered = True + context.hd_account = context.account + """) + + algo = TradingAlgorithm( + script=algo_code, + data_frequency="minute", + sim_params=self.sim_params, + env=self.env + ) + + results = algo.run(self.data_portal) + + # On 1/07, portfolio value is the same as without split + self.assertEqual(results.port_value.iloc[0], 10000) + self.assertAlmostEqual(results.port_value.iloc[1], + 10000 + 780 - 392 - 1) + class TestAlgoScript(TestCase): diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 37e277f7..11b3886e 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -190,7 +190,8 @@ def calculate_results(sim_params, pass msg = perf_tracker.handle_market_close_daily(date) - msg['account'] = perf_tracker.get_account(True, date) + perf_tracker.position_tracker.sync_last_sale_prices(date, False) + msg['account'] = perf_tracker.get_account(True) results.append(copy.deepcopy(msg)) return results @@ -1251,7 +1252,7 @@ class TestPositionPerformance(unittest.TestCase): pp.handle_execution(txn2) dt = trades_1[-2].dt - pt.sync_last_sale_prices(dt) + pt.sync_last_sale_prices(dt, False) pp.calculate_performance() @@ -1279,7 +1280,7 @@ class TestPositionPerformance(unittest.TestCase): net_liquidation=1000.0) dt = trades_1[-1].dt - pt.sync_last_sale_prices(dt) + pt.sync_last_sale_prices(dt, False) pp.calculate_performance() @@ -1354,7 +1355,7 @@ class TestPositionPerformance(unittest.TestCase): shorts_count=0) # Validate that the account attributes were updated. - pt.sync_last_sale_prices(trades[-2].dt) + pt.sync_last_sale_prices(trades[-2].dt, False) # Validate that the account attributes were updated. account = pp.as_account() @@ -1372,7 +1373,7 @@ class TestPositionPerformance(unittest.TestCase): net_liquidation=1000.0) # now simulate a price jump to $11 - pt.sync_last_sale_prices(trades[-1].dt) + pt.sync_last_sale_prices(trades[-1].dt, False) pp.calculate_performance() @@ -1443,7 +1444,7 @@ class TestPositionPerformance(unittest.TestCase): # stocks with a last sale price of 0. self.assertEqual(pp.positions[1].last_sale_price, 10.0) - pt.sync_last_sale_prices(trades[-1].dt) + pt.sync_last_sale_prices(trades[-1].dt, False) pp.calculate_performance() @@ -1554,7 +1555,7 @@ single short-sale transaction""" pt.execute_transaction(txn) pp.handle_execution(txn) - pt.sync_last_sale_prices(trades_1[-1].dt) + pt.sync_last_sale_prices(trades_1[-1].dt, False) pp.calculate_performance() @@ -1610,7 +1611,7 @@ single short-sale transaction""" # simulate a rollover to a new period pp.rollover() - pt.sync_last_sale_prices(trades[-1].dt) + pt.sync_last_sale_prices(trades[-1].dt, False) pp.calculate_performance() @@ -1674,7 +1675,7 @@ single short-sale transaction""" ptTotal.execute_transaction(txn) ppTotal.handle_execution(txn) - ptTotal.sync_last_sale_prices(trades[-1].dt) + ptTotal.sync_last_sale_prices(trades[-1].dt, False) ppTotal.calculate_performance() @@ -1794,7 +1795,7 @@ cost of sole txn in test" # stocks with a last sale price of 0. self.assertEqual(pp.positions[3].last_sale_price, 10.0) - pt.sync_last_sale_prices(trades[-1].dt) + pt.sync_last_sale_prices(trades[-1].dt, False) pp.calculate_performance() self.assertEqual( @@ -1908,7 +1909,7 @@ single short-sale transaction""" pt.execute_transaction(txn) pp.handle_execution(txn) - pt.sync_last_sale_prices(trades[-3].dt) + pt.sync_last_sale_prices(trades[-3].dt, False) pp.calculate_performance() self.assertEqual( @@ -1968,7 +1969,7 @@ single short-sale transaction""" # simulate a rollover to a new period pp.rollover() - pt.sync_last_sale_prices(trades_2[-1].dt) + pt.sync_last_sale_prices(trades_2[-1].dt, False) pp.calculate_performance() self.assertEqual( @@ -2034,13 +2035,13 @@ single short-sale transaction""" ppTotal.position_tracker = ptTotal for trade in trades_1: - ptTotal.sync_last_sale_prices(trade.dt) + ptTotal.sync_last_sale_prices(trade.dt, False) ptTotal.execute_transaction(txn) ppTotal.handle_execution(txn) for trade in trades_2: - ptTotal.sync_last_sale_prices(trade.dt) + ptTotal.sync_last_sale_prices(trade.dt, False) ppTotal.calculate_performance() @@ -2155,7 +2156,7 @@ trade after cover""" pt.execute_transaction(cover_txn) pp.handle_execution(cover_txn) - pt.sync_last_sale_prices(trades[-1].dt) + pt.sync_last_sale_prices(trades[-1].dt, False) pp.calculate_performance() @@ -2263,7 +2264,7 @@ shares in position" "should have a cost basis of 11" ) - pt.sync_last_sale_prices(dt) + pt.sync_last_sale_prices(dt, False) pp.calculate_performance() @@ -2280,7 +2281,7 @@ shares in position" pp.handle_execution(sale_txn) dt = down_tick.dt - pt.sync_last_sale_prices(dt) + pt.sync_last_sale_prices(dt, False) pp.calculate_performance() self.assertEqual( @@ -2316,7 +2317,7 @@ shares in position" pp3.handle_execution(sale_txn) trades.append(down_tick) - pt3.sync_last_sale_prices(trades[-1].dt) + pt3.sync_last_sale_prices(trades[-1].dt, False) pp3.calculate_performance() self.assertEqual( diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 7856c8fd..8e204710 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -1056,9 +1056,10 @@ class TradingAlgorithm(object): def updated_portfolio(self): if self.portfolio_needs_update: + self.perf_tracker.position_tracker.sync_last_sale_prices( + self.datetime, self._in_before_trading_start) self._portfolio = \ - self.perf_tracker.get_portfolio(self.performance_needs_update, - self.datetime) + self.perf_tracker.get_portfolio(self.performance_needs_update) self.portfolio_needs_update = False self.performance_needs_update = False return self._portfolio @@ -1069,9 +1070,10 @@ class TradingAlgorithm(object): def updated_account(self): if self.account_needs_update: + self.perf_tracker.position_tracker.sync_last_sale_prices( + self.datetime, self._in_before_trading_start) self._account = \ - self.perf_tracker.get_account(self.performance_needs_update, - self.datetime) + self.perf_tracker.get_account(self.performance_needs_update) self.account_needs_update = False self.performance_needs_update = False return self._account diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index 0c156ff4..88d68dc5 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -372,15 +372,28 @@ class PositionTracker(object): positions.append(pos.to_dict()) return positions - def sync_last_sale_prices(self, dt): + def sync_last_sale_prices(self, dt, handle_non_market_minutes): data_portal = self._data_portal - for asset, position in iteritems(self.positions): - last_sale_price = data_portal.get_spot_value( - asset, 'price', dt, self.data_frequency - ) + if not handle_non_market_minutes: + for asset, position in iteritems(self.positions): + last_sale_price = data_portal.get_spot_value( + asset, 'price', dt, self.data_frequency + ) - if not np.isnan(last_sale_price): - position.last_sale_price = last_sale_price + if not np.isnan(last_sale_price): + position.last_sale_price = last_sale_price + else: + for asset, position in iteritems(self.positions): + last_sale_price = data_portal.get_adjusted_value( + asset, + 'price', + data_portal.env.previous_market_minute(dt), + dt, + self.data_frequency + ) + + if not np.isnan(last_sale_price): + position.last_sale_price = last_sale_price def stats(self): amounts = [] diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index d3a33cf5..6307416b 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -190,9 +190,8 @@ class PerformanceTracker(object): self.saved_dt = date self.todays_performance.period_close = self.saved_dt - def get_portfolio(self, performance_needs_update, dt): + def get_portfolio(self, performance_needs_update): if performance_needs_update: - self.position_tracker.sync_last_sale_prices(dt) self.update_performance() self.account_needs_update = True return self.cumulative_performance.as_portfolio() @@ -202,9 +201,8 @@ class PerformanceTracker(object): self.cumulative_performance.calculate_performance() self.todays_performance.calculate_performance() - def get_account(self, performance_needs_update, dt): + def get_account(self, performance_needs_update): if performance_needs_update: - self.position_tracker.sync_last_sale_prices(dt) self.update_performance() self.account_needs_update = True if self.account_needs_update: @@ -329,10 +327,10 @@ class PerformanceTracker(object): A tuple of the minute perf packet and daily perf packet. If the market day has not ended, the daily perf packet is None. """ - self.position_tracker.sync_last_sale_prices(dt) + self.position_tracker.sync_last_sale_prices(dt, False) self.update_performance() todays_date = normalize_date(dt) - account = self.get_account(False, dt) + account = self.get_account(False) bench_returns = self.all_benchmark_returns.loc[todays_date:dt] # cumulative returns @@ -357,10 +355,10 @@ class PerformanceTracker(object): Function called after handle_data when running with daily emission rate. """ - self.position_tracker.sync_last_sale_prices(dt) + self.position_tracker.sync_last_sale_prices(dt, False) self.update_performance() completed_date = self.day - account = self.get_account(False, dt) + account = self.get_account(False) benchmark_value = self.all_benchmark_returns[completed_date] diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 5b7c92fe..b9200a5b 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -166,9 +166,6 @@ class AlgorithmSimulator(object): # before cleaning up expired assets. self._cleanup_expired_assets(midnight_dt, position_assets) - # call before trading start - algo.before_trading_start(current_data) - perf_tracker = algo.perf_tracker # handle any splits that impact any positions or any open orders. @@ -183,6 +180,9 @@ class AlgorithmSimulator(object): algo.blotter.process_splits(splits) perf_tracker.position_tracker.handle_splits(splits) + # call before trading start + algo.before_trading_start(current_data) + def handle_benchmark(date): algo.perf_tracker.all_benchmark_returns[date] = \ self.benchmark_source.get_value(date)