diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index efafe40d..dc9f890b 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -2085,6 +2085,16 @@ def order_stuff(context, data): cumulative_perf = \ [r['cumulative_perf'] for r in results if 'cumulative_perf' in r] daily_perf = [r['daily_perf'] for r in results if 'daily_perf' in r] + capital_change_packets = \ + [r['capital_change'] for r in results if 'capital_change' in r] + + self.assertEqual(len(capital_change_packets), 1) + self.assertEqual( + capital_change_packets[0], + {'date': pd.Timestamp('2006-01-06', tz='UTC'), + 'type': 'cash', + 'target': 153000.0 if change_type == 'target' else None, + 'delta': 50000.0}) # 1/03: price = 10, place orders # 1/04: orders execute at price = 11, place orders @@ -2235,6 +2245,17 @@ def order_stuff(context, data): cumulative_perf = \ [r['cumulative_perf'] for r in results if 'cumulative_perf' in r] daily_perf = [r['daily_perf'] for r in results if 'daily_perf' in r] + capital_change_packets = \ + [r['capital_change'] for r in results if 'capital_change' in r] + + self.assertEqual(len(capital_change_packets), len(capital_changes)) + expected = [ + {'date': pd.Timestamp(val[0], tz='UTC'), + 'type': 'cash', + 'target': val[1] if change_type == 'target' else None, + 'delta': 1000.0 if len(values) == 1 else 500.0} + for val in values] + self.assertEqual(capital_change_packets, expected) # 1/03: place orders at price = 100, execute at 101 # 1/04: place orders at price = 490, execute at 491, @@ -2392,6 +2413,17 @@ def order_stuff(context, data): [r['cumulative_perf'] for r in results if 'cumulative_perf' in r] minute_perf = [r['minute_perf'] for r in results if 'minute_perf' in r] daily_perf = [r['daily_perf'] for r in results if 'daily_perf' in r] + capital_change_packets = \ + [r['capital_change'] for r in results if 'capital_change' in r] + + self.assertEqual(len(capital_change_packets), len(capital_changes)) + expected = [ + {'date': pd.Timestamp(val[0], tz='UTC'), + 'type': 'cash', + 'target': val[1] if change_type == 'target' else None, + 'delta': 1000.0 if len(values) == 1 else 500.0} + for val in values] + self.assertEqual(capital_change_packets, expected) # 1/03: place orders at price = 100, execute at 101 # 1/04: place orders at price = 490, execute at 491, diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 0a064703..d163e34e 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -847,14 +847,16 @@ class TradingAlgorithm(object): self.perf_tracker.prepare_capital_change(is_interday) if capital_change['type'] == 'target': - capital_change_amount = capital_change['value'] - \ + target = capital_change['value'] + capital_change_amount = target - \ self.updated_portfolio().portfolio_value self.portfolio_needs_update = True log.info('Processing capital change to target %s at %s. Capital ' - 'change delta is %s' % (capital_change['value'], dt, + 'change delta is %s' % (target, dt, capital_change_amount)) elif capital_change['type'] == 'delta': + target = None capital_change_amount = capital_change['value'] log.info('Processing capital change of delta %s at %s' % (capital_change_amount, dt)) @@ -867,6 +869,14 @@ class TradingAlgorithm(object): self.perf_tracker.process_capital_change(capital_change_amount, is_interday) + yield { + 'capital_change': + {'date': dt, + 'type': 'cash', + 'target': target, + 'delta': capital_change_amount} + } + @api_method def get_environment(self, field='platform'): """Query the execution environment. diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 4aa1c2fa..cd5eb320 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -102,7 +102,8 @@ class AlgorithmSimulator(object): # called every tick (minute or day). algo.on_dt_changed(dt_to_use) - calculate_minute_capital_changes(dt_to_use) + for capital_change in calculate_minute_capital_changes(dt_to_use): + yield capital_change self.simulation_dt = dt_to_use @@ -159,8 +160,10 @@ class AlgorithmSimulator(object): algo.on_dt_changed(midnight_dt) # process any capital changes that came overnight - algo.calculate_capital_changes( - midnight_dt, emission_rate=emission_rate, is_interday=True) + for capital_change in algo.calculate_capital_changes( + midnight_dt, emission_rate=emission_rate, + is_interday=True): + yield capital_change # we want to wait until the clock rolls over to the next day # before cleaning up expired assets. @@ -204,20 +207,22 @@ class AlgorithmSimulator(object): def calculate_minute_capital_changes(dt): # process any capital changes that came between the last # and current minutes - algo.calculate_capital_changes( + return algo.calculate_capital_changes( dt, emission_rate=emission_rate, is_interday=False) else: def execute_order_cancellation_policy(): pass def calculate_minute_capital_changes(dt): - pass + return [] for dt, action in self.clock: if action == BAR: - every_bar(dt) + for capital_change_packet in every_bar(dt): + yield capital_change_packet elif action == DAY_START: - once_a_day(dt) + for capital_change_packet in once_a_day(dt): + yield capital_change_packet elif action == DAY_END: # End of the day. if emission_rate == 'daily':