BUG: Don't save empty positions when user access non-existent position

Previously, whenever we try to access a missing value on the Positions
dict, we return a default Position and save it to the dict. Instead,
just return the Position
This commit is contained in:
Andrew Liang
2016-04-26 10:15:54 -04:00
parent 50f4917341
commit d69b960c49
4 changed files with 39 additions and 4 deletions
+20
View File
@@ -154,6 +154,7 @@ from zipline.test_algorithms import (
call_with_bad_kwargs_get_open_orders,
call_with_good_kwargs_get_open_orders,
call_with_no_kwargs_get_open_orders,
empty_positions,
no_handle_data,
)
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
@@ -1799,6 +1800,25 @@ def handle_data(context, data):
else:
algo.run(self.data_portal)
def test_empty_positions(self):
"""
Test that when we try context.portfolio.positions[stock] on a stock
for which we have no positions, we return a Position with values 0
(but more importantly, we don't crash) and don't save this Position
to the user-facing dictionary PositionTracker._positions_store
"""
algo = TradingAlgorithm(
script=empty_positions,
sim_params=self.sim_params,
env=self.env
)
results = algo.run(self.data_portal)
num_positions = results.num_positions
amounts = results.amounts
self.assertTrue(all(num_positions == 0))
self.assertTrue(all(amounts == 0))
class TestGetDatetime(WithLogger,
WithSimParams,
@@ -350,14 +350,16 @@ class PositionTracker(object):
pass
continue
# Note that this will create a position if we don't currently have
# an entry
position = positions[sid]
position = zp.Position(sid)
position.amount = pos.amount
position.cost_basis = pos.cost_basis
position.last_sale_price = pos.last_sale_price
position.last_sale_date = pos.last_sale_date
# Adds the new position if we didn't have one before, or overwrite
# one we have currently
positions[sid] = position
return positions
def get_positions_list(self):
-1
View File
@@ -164,5 +164,4 @@ class Positions(dict):
def __missing__(self, key):
pos = Position(key)
self[key] = pos
return pos
+14
View File
@@ -1126,3 +1126,17 @@ def initialize(context):
def handle_data(context, data):
context.get_open_orders(symbol('TEST'))
"""
empty_positions = """
from zipline.api import record, schedule_function, time_rules, date_rules, \
symbol
def initialize(context):
schedule_function(test_history, date_rules.every_day(),
time_rules.market_open(hours=1))
context.sid = symbol('TEST')
def test_history(context,data):
record(amounts=context.portfolio.positions[context.sid].amount)
record(num_positions=len(context.portfolio.positions))
"""