mirror of
https://github.com/wassname/options_backtester.git
synced 2026-06-27 18:05:27 +08:00
sma example
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import pandas as pd
|
||||
import pyprind
|
||||
|
||||
import numpy as np
|
||||
from .portfolio import Portfolio
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class Backtest:
|
||||
def data(self, data):
|
||||
self._data = data
|
||||
|
||||
def run(self, initial_capital=1_000_000, periods=1, sma_months=None):
|
||||
def run(self, initial_capital=1_000_000, periods=1, sma_days=None):
|
||||
"""Runs a backtest and returns a dataframe with the daily balance"""
|
||||
assert self._data is not None
|
||||
assert self._portfolio is not None
|
||||
@@ -36,9 +36,9 @@ class Backtest:
|
||||
self.current_cash = initial_capital
|
||||
self.inventory = pd.DataFrame(columns=['symbol', 'cost', 'qty'])
|
||||
self.balance = pd.DataFrame()
|
||||
if sma_months:
|
||||
self._data.sma(sma_months)
|
||||
|
||||
if sma_days:
|
||||
self._data.sma(sma_days)
|
||||
|
||||
data_iterator = self._data.iter_dates()
|
||||
|
||||
first_day = self._data['date'].min()
|
||||
@@ -55,11 +55,12 @@ class Backtest:
|
||||
index=[self._data.start_date - pd.Timedelta(1, unit='day')])
|
||||
|
||||
for date, data in data_iterator:
|
||||
if date == first_day:
|
||||
self._rebalance_portfolio(data)
|
||||
|
||||
if date == first_day:
|
||||
self._rebalance_portfolio(data, sma_days)
|
||||
self._update_balance(date, data)
|
||||
if date in rebalancing_days:
|
||||
self._rebalance_portfolio(data)
|
||||
self._rebalance_portfolio(data, sma_days)
|
||||
|
||||
bar.update()
|
||||
|
||||
@@ -68,23 +69,29 @@ class Backtest:
|
||||
|
||||
return self.balance
|
||||
|
||||
def _rebalance_portfolio(self, data):
|
||||
def _rebalance_portfolio(self, data, sma_days):
|
||||
"""Rebalances the portfolio so that the total money is allocated according to the given percentages"""
|
||||
money_total = self.current_cash + self.current_capital
|
||||
|
||||
for asset in self._portfolio.assets:
|
||||
query = '{} == "{}"'.format(self.schema['symbol'], asset.symbol)
|
||||
|
||||
asset_current = data.query(query)
|
||||
|
||||
asset_price = asset_current[self.schema['adjClose']].values[0]
|
||||
|
||||
qty = (money_total * asset.percentage) // asset_price
|
||||
if sma_days is not None:
|
||||
if asset_current['sma'].values[0] < asset_price:
|
||||
qty = (money_total * asset.percentage) // asset_price
|
||||
else:
|
||||
qty = 0
|
||||
|
||||
else:
|
||||
qty = (money_total * asset.percentage) // asset_price
|
||||
|
||||
inventory_entry = self.inventory.query(query)
|
||||
self.inventory.drop(inventory_entry.index, inplace=True)
|
||||
updated_asset = pd.Series([asset.symbol, asset_price, qty])
|
||||
updated_asset.index = self.inventory.columns
|
||||
self.inventory = self.inventory.append(updated_asset, ignore_index=True)
|
||||
|
||||
# Update current cash
|
||||
invested_capital = sum(self.inventory['cost'] * self.inventory['qty'])
|
||||
self.current_cash = money_total - invested_capital
|
||||
@@ -112,5 +119,3 @@ class Backtest:
|
||||
'capital': money_total,
|
||||
}, name=date)
|
||||
self.balance = self.balance.append(row)
|
||||
|
||||
|
||||
|
||||
+12
-13
@@ -63,17 +63,16 @@ def monthly_returns_heatmap(report):
|
||||
return chart
|
||||
|
||||
|
||||
def historical_values(data_sma, data_symbol, asset_name):
|
||||
def sma_graph(data):
|
||||
|
||||
asset_sma = pd.DataFrame(data_sma[asset_name])
|
||||
asset_sma = asset_sma.rename(columns={asset_name: 'value'})
|
||||
asset_sma['id'] = ['sma value'] * (len(asset_sma.index))
|
||||
asset_sma = asset_sma.dropna()
|
||||
asset_value = pd.DataFrame(data_symbol[asset_name]['Adj Close'])
|
||||
asset_value = asset_value.rename(columns={'Adj Close': 'value'})
|
||||
asset_value['id'] = ['Adj Close'] * (len(asset_value.index))
|
||||
|
||||
asset_value = asset_value.append(asset_sma)
|
||||
asset_value['index'] = asset_value.index
|
||||
plot = alt.Chart(asset_value).mark_line().encode(x='index:T', y=alt.Y('value:Q'), color='id')
|
||||
return plot
|
||||
price_chart = alt.Chart(
|
||||
data,
|
||||
width=700,
|
||||
height=350,
|
||||
).mark_line().encode(x='date:T', y=alt.Y('adjClose:Q'), color='symbol:N', opacity=alt.value(0.3))
|
||||
sma_chart = alt.Chart(
|
||||
data,
|
||||
width=700,
|
||||
height=350,
|
||||
).mark_line(strokeDash=[1, 1]).encode(x='date:T', y=alt.Y('sma:Q'), color='symbol:N')
|
||||
return price_chart + sma_chart
|
||||
@@ -70,7 +70,10 @@ class HistoricalAssetData:
|
||||
"""Returns default schema for Historical Asset Data"""
|
||||
schema = Schema.canonical()
|
||||
return schema
|
||||
def sma(self,months):
|
||||
|
||||
def sma(self, months):
|
||||
sma = self._data.groupby('symbol').rolling(months)['adjClose'].mean()
|
||||
sma = sma.reset_index('symbol').sort_index()
|
||||
sma = sma.fillna(0)
|
||||
self._data['sma'] = sma['adjClose']
|
||||
self.schema.update({'sma': 'sma'})
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user