diff --git a/backtester/backtester.py b/backtester/backtester.py index 2eaf95d..e87cccd 100644 --- a/backtester/backtester.py +++ b/backtester/backtester.py @@ -113,7 +113,8 @@ class Backtest: """ leg_candidates = [ - self._strategy._exit_candidates(l.direction, self.inventory[l.name], options) for l in self._strategy.legs + self._strategy._exit_candidates(l.direction, self.inventory[l.name], options, self.inventory.index) + for l in self._strategy.legs ] # If a contract is missing we replace the NaN values with those of the inventory diff --git a/backtester/strategy/strategy.py b/backtester/strategy/strategy.py index 522657d..4808ab2 100644 --- a/backtester/strategy/strategy.py +++ b/backtester/strategy/strategy.py @@ -101,16 +101,18 @@ class Strategy: pd.DataFrame: Exit signals """ - leg_candidates = [self._exit_candidates(l.direction, inventory[l.name], options) for l in self.legs] + leg_candidates = [ + self._exit_candidates(l.direction, inventory[l.name], options, inventory.index) for l in self.legs + ] - filter_mask = [] + filter_masks = [] for i, leg in enumerate(self.legs): flt = leg.exit_filter # This mask is to ensure that legs with missing contracts exit. missing_contracts_mask = leg_candidates[i]['cost'].isna() - filter_mask.append(flt(leg_candidates[i]) | missing_contracts_mask) + filter_masks.append(flt(leg_candidates[i]) | missing_contracts_mask) fields = self._signal_fields((~leg.direction).value) leg_candidates[i] = leg_candidates[i].loc[:, fields.values()] leg_candidates[i].columns = pd.MultiIndex.from_product([["leg_{}".format(i + 1)], @@ -133,7 +135,7 @@ class Strategy: # Compute which contracts need to exit, either because of price thresholds or user exit filters threshold_exits = self._filter_thresholds(inventory['totals']['cost'], total_costs) - filter_mask = reduce(lambda x, y: x | y, filter_mask) + filter_mask = reduce(lambda x, y: x | y, filter_masks) exits_mask = threshold_exits | filter_mask exits = candidates[exits_mask] @@ -226,7 +228,7 @@ class Strategy: return pd.concat(dfs, axis=1) - def _exit_candidates(self, direction, inventory_leg, options): + def _exit_candidates(self, direction, inventory_leg, options, inventory_index): """Returns the exit candidates for the given inventory leg with their order and cost (positive for STC orders). Args: @@ -244,6 +246,9 @@ class Strategy: fields = self._signal_fields((~direction).value) options = options.rename(columns=fields) candidates = inventory_leg[['contract']].merge(options, how='left', on='contract') + # candidates.index needs to be the same as the inventory's so that the exit masks that are constructed + # from it can be correctly applied to the inventory. + candidates.index = inventory_index order = get_order(direction, Signal.EXIT) candidates['order'] = order