diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 62203612..a2688b28 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -154,6 +154,7 @@ from zipline.test_algorithms import ( call_with_bad_kwargs_get_open_orders, call_with_good_kwargs_get_open_orders, call_with_no_kwargs_get_open_orders, + empty_positions, no_handle_data, ) from zipline.utils.api_support import ZiplineAPI, set_algo_instance @@ -1799,6 +1800,25 @@ def handle_data(context, data): else: algo.run(self.data_portal) + def test_empty_positions(self): + """ + Test that when we try context.portfolio.positions[stock] on a stock + for which we have no positions, we return a Position with values 0 + (but more importantly, we don't crash) and don't save this Position + to the user-facing dictionary PositionTracker._positions_store + """ + algo = TradingAlgorithm( + script=empty_positions, + sim_params=self.sim_params, + env=self.env + ) + + results = algo.run(self.data_portal) + num_positions = results.num_positions + amounts = results.amounts + self.assertTrue(all(num_positions == 0)) + self.assertTrue(all(amounts == 0)) + class TestGetDatetime(WithLogger, WithSimParams, diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index ea1600a1..1893e489 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -350,14 +350,16 @@ class PositionTracker(object): pass continue - # Note that this will create a position if we don't currently have - # an entry - position = positions[sid] + position = zp.Position(sid) position.amount = pos.amount position.cost_basis = pos.cost_basis position.last_sale_price = pos.last_sale_price position.last_sale_date = pos.last_sale_date + # Adds the new position if we didn't have one before, or overwrite + # one we have currently + positions[sid] = position + return positions def get_positions_list(self): diff --git a/zipline/protocol.py b/zipline/protocol.py index 9db94914..c7915590 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -164,5 +164,4 @@ class Positions(dict): def __missing__(self, key): pos = Position(key) - self[key] = pos return pos diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index bc2ea6dd..39c4151b 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -1126,3 +1126,17 @@ def initialize(context): def handle_data(context, data): context.get_open_orders(symbol('TEST')) """ + +empty_positions = """ +from zipline.api import record, schedule_function, time_rules, date_rules, \ + symbol + +def initialize(context): + schedule_function(test_history, date_rules.every_day(), + time_rules.market_open(hours=1)) + context.sid = symbol('TEST') + +def test_history(context,data): + record(amounts=context.portfolio.positions[context.sid].amount) + record(num_positions=len(context.portfolio.positions)) +"""