From 29561e2b0c225ad8fe271f45b79fb0442c4e3436 Mon Sep 17 00:00:00 2001 From: Juan Pablo Amoroso Date: Thu, 6 Feb 2020 18:32:39 -0300 Subject: [PATCH] Moved initial capital to Backtest constructor. Fixed Schema fields --- asset_backtester/backtester.py | 8 ++++---- backtester/datahandler/schema.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/asset_backtester/backtester.py b/asset_backtester/backtester.py index f8008d8..7e58898 100644 --- a/asset_backtester/backtester.py +++ b/asset_backtester/backtester.py @@ -1,14 +1,14 @@ import pandas as pd import pyprind -import numpy as np from .portfolio import Portfolio class Backtest: - def __init__(self, schema): + def __init__(self, schema, initial_capital=1_000_000): self.schema = schema self._portfolio = None self._data = None + self.initial_capital = initial_capital @property def portfolio(self): @@ -27,13 +27,13 @@ class Backtest: def data(self, data): self._data = data - def run(self, initial_capital=1_000_000, periods=1, sma_days=None): + def run(self, periods=1, sma_days=None): """Runs a backtest and returns a dataframe with the daily balance""" assert self._data is not None assert self._portfolio is not None self.current_capital = 0 - self.current_cash = initial_capital + self.current_cash = self.initial_capital self.inventory = pd.DataFrame(columns=['symbol', 'cost', 'qty']) self.balance = pd.DataFrame() if sma_days: diff --git a/backtester/datahandler/schema.py b/backtester/datahandler/schema.py index 99edfce..7b46e6a 100644 --- a/backtester/datahandler/schema.py +++ b/backtester/datahandler/schema.py @@ -4,8 +4,8 @@ class Schema: """ columns = [ - "underlying", "underlying_last", "date", "contract", "type", - "expiration", "strike", "bid", "ask", "volume", "open_interest" + "underlying", "underlying_last", "date", "contract", "type", "expiration", "strike", "bid", "ask", "volume", + "open_interest" ] def canonical(): @@ -42,8 +42,7 @@ class Schema: return iter(self._mappings.items()) def __repr__(self): - return "Schema({})".format( - [Field(k, m) for k, m in self._mappings.items()]) + return "Schema({})".format([Field(k, m) for k, m in self._mappings.items()]) def __eq__(self, other): return self._mappings == other._mappings @@ -60,6 +59,7 @@ class Field: def _create_filter(self, op, other): if isinstance(other, Field): + query = Field._format_query(self.mapping, op, other.mapping) else: query = Field._format_query(self.mapping, op, other) @@ -68,8 +68,7 @@ class Field: def _combine_fields(self, op, other, invert=False): if isinstance(other, Field): name = Field._format_query(self.name, op, other.name, invert) - mapping = Field._format_query(self.mapping, op, other.mapping, - invert) + mapping = Field._format_query(self.mapping, op, other.mapping, invert) elif isinstance(other, (int, float)): name = Field._format_query(self.name, op, other, invert) mapping = Field._format_query(self.mapping, op, other, invert)