DEV: Apply adjustments for portfolio and account in BTS

completely copied from https://github.com/quantopian/zipline/pull/1104/

All credit goes to Andrew Liang (@lianga888)
This commit is contained in:
Jean Bredeche
2016-04-05 11:08:04 -04:00
parent 34f47da033
commit dc01c45dc4
6 changed files with 217 additions and 45 deletions
+163 -5
View File
@@ -35,7 +35,11 @@ from zipline.api import FixedSlippage
from zipline.data.data_portal import DataPortal
from zipline.data.minute_bars import BcolzMinuteBarWriter, \
US_EQUITIES_MINUTES_PER_DAY, BcolzMinuteBarReader
from zipline.data.us_equity_pricing import BcolzDailyBarReader
from zipline.data.us_equity_pricing import (
BcolzDailyBarReader,
SQLiteAdjustmentWriter,
SQLiteAdjustmentReader
)
from zipline.finance.commission import PerShare
from zipline.finance.execution import LimitOrder
from zipline.finance.order import ORDER_STATUS
@@ -47,7 +51,9 @@ from zipline.testing.core import (
create_data_portal,
create_data_portal_from_trade_history,
DailyBarWriterFromDataFrames,
create_daily_df_for_asset, write_minute_data_for_asset,
create_daily_df_for_asset,
write_minute_data_for_asset,
MockDailyBarReader,
make_test_handler)
from zipline.errors import (
OrderDuringInitialize,
@@ -108,6 +114,7 @@ from zipline.testing import (
setup_logger,
teardown_logger,
parameter_space,
str_to_seconds
)
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
from zipline.utils.context_tricks import CallbackManager
@@ -1018,7 +1025,7 @@ class TestBeforeTradingStart(TestCase):
)
equities_data = {}
for sid in [1, 2]:
for sid in [1, 2, 3]:
equities_data[sid] = {
"start_date": cls.trading_days[0],
"end_date": cls.trading_days[-1],
@@ -1029,6 +1036,7 @@ class TestBeforeTradingStart(TestCase):
cls.asset1 = cls.env.asset_finder.retrieve_asset(1)
cls.asset2 = cls.env.asset_finder.retrieve_asset(2)
cls.SPLIT_ASSET = cls.env.asset_finder.retrieve_asset(3)
market_opens = cls.env.open_and_closes.market_open.loc[
cls.trading_days]
@@ -1049,6 +1057,24 @@ class TestBeforeTradingStart(TestCase):
cls.trading_days[-1], sid
)
# Write data with split asset
asset_minutes = cls.env.minutes_for_days_in_range(
cls.trading_days[0], cls.trading_days[-1])
minutes_count = len(asset_minutes)
minutes_arr = np.array(range(1, 1 + minutes_count))
df = pd.DataFrame({
"open": minutes_arr + 1,
"high": minutes_arr + 2,
"low": minutes_arr - 1,
"close": minutes_arr,
"volume": 100 * minutes_arr,
"dt": asset_minutes
}).set_index("dt")
df.iloc[780:] = df.iloc[780:] / 2.0
minute_writer.write(3, df)
# asset2 only trades every 50 minutes
write_minute_data_for_asset(
cls.env, minute_writer, cls.trading_days[0],
@@ -1056,12 +1082,15 @@ class TestBeforeTradingStart(TestCase):
)
cls.minute_reader = BcolzMinuteBarReader(cls.tempdir.path)
cls.adj_reader = cls.create_adjustments_reader()
cls.daily_path = cls.tempdir.getpath("testdaily.bcolz")
dfs = {
1: create_daily_df_for_asset(cls.env, cls.trading_days[0],
cls.trading_days[-1]),
2: create_daily_df_for_asset(cls.env, cls.trading_days[0],
cls.trading_days[-1]),
3: create_daily_df_for_asset(cls.env, cls.trading_days[0],
cls.trading_days[-1])
}
daily_writer = DailyBarWriterFromDataFrames(dfs)
@@ -1077,9 +1106,45 @@ class TestBeforeTradingStart(TestCase):
cls.data_portal = DataPortal(
env=cls.env,
equity_daily_reader=BcolzDailyBarReader(cls.daily_path),
equity_minute_reader=cls.minute_reader
equity_minute_reader=cls.minute_reader,
adjustment_reader=cls.adj_reader
)
@classmethod
def create_adjustments_reader(cls):
path = cls.tempdir.getpath("test_adjustments.db")
adj_writer = SQLiteAdjustmentWriter(
path,
cls.env.trading_days,
MockDailyBarReader()
)
splits = pd.DataFrame([
{
'effective_date': str_to_seconds("2016-01-07"),
'ratio': 0.5,
'sid': cls.SPLIT_ASSET.sid
}
])
# Mergers and Dividends are not tested, but we need to have these
# anyway
mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid'])
mergers.effective_date = mergers.effective_date.astype(int)
mergers.ratio = mergers.ratio.astype(float)
mergers.sid = mergers.sid.astype(int)
dividends = pd.DataFrame({}, columns=['ex_date', 'record_date',
'declared_date', 'pay_date',
'amount', 'sid'])
dividends.amount = dividends.amount.astype(float)
dividends.sid = dividends.sid.astype(int)
adj_writer.write(splits, mergers, dividends)
return SQLiteAdjustmentReader(path)
@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()
@@ -1209,14 +1274,21 @@ class TestBeforeTradingStart(TestCase):
def initialize(context):
context.ordered = False
context.hd_portfolio = context.portfolio
def before_trading_start(context, data):
record(pos_value=context.portfolio.positions_value)
bts_portfolio = context.portfolio
# Assert that the portfolio in BTS is the same as the last
# portfolio in handle_data
assert (context.hd_portfolio == bts_portfolio)
record(pos_value=bts_portfolio.positions_value)
def handle_data(context, data):
if not context.ordered:
order(sid(1), 1)
context.ordered = True
context.hd_portfolio = context.portfolio
""")
algo = TradingAlgorithm(
@@ -1241,14 +1313,21 @@ class TestBeforeTradingStart(TestCase):
def initialize(context):
context.ordered = False
context.hd_account = context.account
def before_trading_start(context, data):
bts_account = context.account
# Assert that the account in BTS is the same as the last account
# in handle_data
assert (context.hd_account == bts_account)
record(port_value=context.account.equity_with_loan)
def handle_data(context, data):
if not context.ordered:
order(sid(1), 1)
context.ordered = True
context.hd_acount = context.account
""")
algo = TradingAlgorithm(
@@ -1268,6 +1347,85 @@ class TestBeforeTradingStart(TestCase):
self.assertAlmostEqual(results.port_value.iloc[1],
10000 + 780 - 392 - 1)
def test_portfolio_bts_with_overnight_split(self):
algo_code = dedent("""
from zipline.api import order, sid, record
def initialize(context):
context.ordered = False
context.hd_portfolio = context.portfolio
def before_trading_start(context, data):
bts_portfolio = context.portfolio
# Assert that the portfolio in BTS is the same as the last
# portfolio in handle_data, except for the positions
for k in bts_portfolio.__dict__:
if k != 'positions':
assert (context.hd_portfolio.__dict__[k]
== bts_portfolio.__dict__[k])
record(pos_value=bts_portfolio.positions_value)
record(pos_amount=bts_portfolio.positions[sid(3)]['amount'])
record(last_sale_price=bts_portfolio.positions[sid(3)]
['last_sale_price'])
def handle_data(context, data):
if not context.ordered:
order(sid(3), 1)
context.ordered = True
context.hd_portfolio = context.portfolio
""")
algo = TradingAlgorithm(
script=algo_code,
data_frequency="minute",
sim_params=self.sim_params,
env=self.env
)
results = algo.run(self.data_portal)
# On 1/07, positions value should by 780, same as without split
self.assertEqual(results.pos_value.iloc[0], 0)
self.assertEqual(results.pos_value.iloc[1], 780)
# On 1/07, after applying the split, 1 share becomes 2
self.assertEqual(results.pos_amount.iloc[0], 0)
self.assertEqual(results.pos_amount.iloc[1], 2)
# On 1/07, after applying the split, last sale price is halved
self.assertEqual(results.last_sale_price.iloc[0], 0)
self.assertEqual(results.last_sale_price.iloc[1], 390)
def test_account_bts_with_overnight_split(self):
algo_code = dedent("""
from zipline.api import order, sid, record
def initialize(context):
context.ordered = False
context.hd_account = context.account
def before_trading_start(context, data):
bts_account = context.account
# Assert that the account in BTS is the same as the last account
# in handle_data
assert (context.hd_account == bts_account)
record(port_value=bts_account.equity_with_loan)
def handle_data(context, data):
if not context.ordered:
order(sid(1), 1)
context.ordered = True
context.hd_account = context.account
""")
algo = TradingAlgorithm(
script=algo_code,
data_frequency="minute",
sim_params=self.sim_params,
env=self.env
)
results = algo.run(self.data_portal)
# On 1/07, portfolio value is the same as without split
self.assertEqual(results.port_value.iloc[0], 10000)
self.assertAlmostEqual(results.port_value.iloc[1],
10000 + 780 - 392 - 1)
class TestAlgoScript(TestCase):
+19 -18
View File
@@ -190,7 +190,8 @@ def calculate_results(sim_params,
pass
msg = perf_tracker.handle_market_close_daily(date)
msg['account'] = perf_tracker.get_account(True, date)
perf_tracker.position_tracker.sync_last_sale_prices(date, False)
msg['account'] = perf_tracker.get_account(True)
results.append(copy.deepcopy(msg))
return results
@@ -1251,7 +1252,7 @@ class TestPositionPerformance(unittest.TestCase):
pp.handle_execution(txn2)
dt = trades_1[-2].dt
pt.sync_last_sale_prices(dt)
pt.sync_last_sale_prices(dt, False)
pp.calculate_performance()
@@ -1279,7 +1280,7 @@ class TestPositionPerformance(unittest.TestCase):
net_liquidation=1000.0)
dt = trades_1[-1].dt
pt.sync_last_sale_prices(dt)
pt.sync_last_sale_prices(dt, False)
pp.calculate_performance()
@@ -1354,7 +1355,7 @@ class TestPositionPerformance(unittest.TestCase):
shorts_count=0)
# Validate that the account attributes were updated.
pt.sync_last_sale_prices(trades[-2].dt)
pt.sync_last_sale_prices(trades[-2].dt, False)
# Validate that the account attributes were updated.
account = pp.as_account()
@@ -1372,7 +1373,7 @@ class TestPositionPerformance(unittest.TestCase):
net_liquidation=1000.0)
# now simulate a price jump to $11
pt.sync_last_sale_prices(trades[-1].dt)
pt.sync_last_sale_prices(trades[-1].dt, False)
pp.calculate_performance()
@@ -1443,7 +1444,7 @@ class TestPositionPerformance(unittest.TestCase):
# stocks with a last sale price of 0.
self.assertEqual(pp.positions[1].last_sale_price, 10.0)
pt.sync_last_sale_prices(trades[-1].dt)
pt.sync_last_sale_prices(trades[-1].dt, False)
pp.calculate_performance()
@@ -1554,7 +1555,7 @@ single short-sale transaction"""
pt.execute_transaction(txn)
pp.handle_execution(txn)
pt.sync_last_sale_prices(trades_1[-1].dt)
pt.sync_last_sale_prices(trades_1[-1].dt, False)
pp.calculate_performance()
@@ -1610,7 +1611,7 @@ single short-sale transaction"""
# simulate a rollover to a new period
pp.rollover()
pt.sync_last_sale_prices(trades[-1].dt)
pt.sync_last_sale_prices(trades[-1].dt, False)
pp.calculate_performance()
@@ -1674,7 +1675,7 @@ single short-sale transaction"""
ptTotal.execute_transaction(txn)
ppTotal.handle_execution(txn)
ptTotal.sync_last_sale_prices(trades[-1].dt)
ptTotal.sync_last_sale_prices(trades[-1].dt, False)
ppTotal.calculate_performance()
@@ -1794,7 +1795,7 @@ cost of sole txn in test"
# stocks with a last sale price of 0.
self.assertEqual(pp.positions[3].last_sale_price, 10.0)
pt.sync_last_sale_prices(trades[-1].dt)
pt.sync_last_sale_prices(trades[-1].dt, False)
pp.calculate_performance()
self.assertEqual(
@@ -1908,7 +1909,7 @@ single short-sale transaction"""
pt.execute_transaction(txn)
pp.handle_execution(txn)
pt.sync_last_sale_prices(trades[-3].dt)
pt.sync_last_sale_prices(trades[-3].dt, False)
pp.calculate_performance()
self.assertEqual(
@@ -1968,7 +1969,7 @@ single short-sale transaction"""
# simulate a rollover to a new period
pp.rollover()
pt.sync_last_sale_prices(trades_2[-1].dt)
pt.sync_last_sale_prices(trades_2[-1].dt, False)
pp.calculate_performance()
self.assertEqual(
@@ -2034,13 +2035,13 @@ single short-sale transaction"""
ppTotal.position_tracker = ptTotal
for trade in trades_1:
ptTotal.sync_last_sale_prices(trade.dt)
ptTotal.sync_last_sale_prices(trade.dt, False)
ptTotal.execute_transaction(txn)
ppTotal.handle_execution(txn)
for trade in trades_2:
ptTotal.sync_last_sale_prices(trade.dt)
ptTotal.sync_last_sale_prices(trade.dt, False)
ppTotal.calculate_performance()
@@ -2155,7 +2156,7 @@ trade after cover"""
pt.execute_transaction(cover_txn)
pp.handle_execution(cover_txn)
pt.sync_last_sale_prices(trades[-1].dt)
pt.sync_last_sale_prices(trades[-1].dt, False)
pp.calculate_performance()
@@ -2263,7 +2264,7 @@ shares in position"
"should have a cost basis of 11"
)
pt.sync_last_sale_prices(dt)
pt.sync_last_sale_prices(dt, False)
pp.calculate_performance()
@@ -2280,7 +2281,7 @@ shares in position"
pp.handle_execution(sale_txn)
dt = down_tick.dt
pt.sync_last_sale_prices(dt)
pt.sync_last_sale_prices(dt, False)
pp.calculate_performance()
self.assertEqual(
@@ -2316,7 +2317,7 @@ shares in position"
pp3.handle_execution(sale_txn)
trades.append(down_tick)
pt3.sync_last_sale_prices(trades[-1].dt)
pt3.sync_last_sale_prices(trades[-1].dt, False)
pp3.calculate_performance()
self.assertEqual(
+6 -4
View File
@@ -1056,9 +1056,10 @@ class TradingAlgorithm(object):
def updated_portfolio(self):
if self.portfolio_needs_update:
self.perf_tracker.position_tracker.sync_last_sale_prices(
self.datetime, self._in_before_trading_start)
self._portfolio = \
self.perf_tracker.get_portfolio(self.performance_needs_update,
self.datetime)
self.perf_tracker.get_portfolio(self.performance_needs_update)
self.portfolio_needs_update = False
self.performance_needs_update = False
return self._portfolio
@@ -1069,9 +1070,10 @@ class TradingAlgorithm(object):
def updated_account(self):
if self.account_needs_update:
self.perf_tracker.position_tracker.sync_last_sale_prices(
self.datetime, self._in_before_trading_start)
self._account = \
self.perf_tracker.get_account(self.performance_needs_update,
self.datetime)
self.perf_tracker.get_account(self.performance_needs_update)
self.account_needs_update = False
self.performance_needs_update = False
return self._account
@@ -372,15 +372,28 @@ class PositionTracker(object):
positions.append(pos.to_dict())
return positions
def sync_last_sale_prices(self, dt):
def sync_last_sale_prices(self, dt, handle_non_market_minutes):
data_portal = self._data_portal
for asset, position in iteritems(self.positions):
last_sale_price = data_portal.get_spot_value(
asset, 'price', dt, self.data_frequency
)
if not handle_non_market_minutes:
for asset, position in iteritems(self.positions):
last_sale_price = data_portal.get_spot_value(
asset, 'price', dt, self.data_frequency
)
if not np.isnan(last_sale_price):
position.last_sale_price = last_sale_price
if not np.isnan(last_sale_price):
position.last_sale_price = last_sale_price
else:
for asset, position in iteritems(self.positions):
last_sale_price = data_portal.get_adjusted_value(
asset,
'price',
data_portal.env.previous_market_minute(dt),
dt,
self.data_frequency
)
if not np.isnan(last_sale_price):
position.last_sale_price = last_sale_price
def stats(self):
amounts = []
+6 -8
View File
@@ -190,9 +190,8 @@ class PerformanceTracker(object):
self.saved_dt = date
self.todays_performance.period_close = self.saved_dt
def get_portfolio(self, performance_needs_update, dt):
def get_portfolio(self, performance_needs_update):
if performance_needs_update:
self.position_tracker.sync_last_sale_prices(dt)
self.update_performance()
self.account_needs_update = True
return self.cumulative_performance.as_portfolio()
@@ -202,9 +201,8 @@ class PerformanceTracker(object):
self.cumulative_performance.calculate_performance()
self.todays_performance.calculate_performance()
def get_account(self, performance_needs_update, dt):
def get_account(self, performance_needs_update):
if performance_needs_update:
self.position_tracker.sync_last_sale_prices(dt)
self.update_performance()
self.account_needs_update = True
if self.account_needs_update:
@@ -329,10 +327,10 @@ class PerformanceTracker(object):
A tuple of the minute perf packet and daily perf packet.
If the market day has not ended, the daily perf packet is None.
"""
self.position_tracker.sync_last_sale_prices(dt)
self.position_tracker.sync_last_sale_prices(dt, False)
self.update_performance()
todays_date = normalize_date(dt)
account = self.get_account(False, dt)
account = self.get_account(False)
bench_returns = self.all_benchmark_returns.loc[todays_date:dt]
# cumulative returns
@@ -357,10 +355,10 @@ class PerformanceTracker(object):
Function called after handle_data when running with daily emission
rate.
"""
self.position_tracker.sync_last_sale_prices(dt)
self.position_tracker.sync_last_sale_prices(dt, False)
self.update_performance()
completed_date = self.day
account = self.get_account(False, dt)
account = self.get_account(False)
benchmark_value = self.all_benchmark_returns[completed_date]
+3 -3
View File
@@ -166,9 +166,6 @@ class AlgorithmSimulator(object):
# before cleaning up expired assets.
self._cleanup_expired_assets(midnight_dt, position_assets)
# call before trading start
algo.before_trading_start(current_data)
perf_tracker = algo.perf_tracker
# handle any splits that impact any positions or any open orders.
@@ -183,6 +180,9 @@ class AlgorithmSimulator(object):
algo.blotter.process_splits(splits)
perf_tracker.position_tracker.handle_splits(splits)
# call before trading start
algo.before_trading_start(current_data)
def handle_benchmark(date):
algo.perf_tracker.all_benchmark_returns[date] = \
self.benchmark_source.get_value(date)