mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 03:23:27 +08:00
BUG: Don't save empty positions when user access non-existent position
Previously, whenever we try to access a missing value on the Positions dict, we return a default Position and save it to the dict. Instead, just return the Position
This commit is contained in:
@@ -154,6 +154,7 @@ from zipline.test_algorithms import (
|
||||
call_with_bad_kwargs_get_open_orders,
|
||||
call_with_good_kwargs_get_open_orders,
|
||||
call_with_no_kwargs_get_open_orders,
|
||||
empty_positions,
|
||||
no_handle_data,
|
||||
)
|
||||
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
|
||||
@@ -1799,6 +1800,25 @@ def handle_data(context, data):
|
||||
else:
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_empty_positions(self):
|
||||
"""
|
||||
Test that when we try context.portfolio.positions[stock] on a stock
|
||||
for which we have no positions, we return a Position with values 0
|
||||
(but more importantly, we don't crash) and don't save this Position
|
||||
to the user-facing dictionary PositionTracker._positions_store
|
||||
"""
|
||||
algo = TradingAlgorithm(
|
||||
script=empty_positions,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env
|
||||
)
|
||||
|
||||
results = algo.run(self.data_portal)
|
||||
num_positions = results.num_positions
|
||||
amounts = results.amounts
|
||||
self.assertTrue(all(num_positions == 0))
|
||||
self.assertTrue(all(amounts == 0))
|
||||
|
||||
|
||||
class TestGetDatetime(WithLogger,
|
||||
WithSimParams,
|
||||
|
||||
@@ -350,14 +350,16 @@ class PositionTracker(object):
|
||||
pass
|
||||
continue
|
||||
|
||||
# Note that this will create a position if we don't currently have
|
||||
# an entry
|
||||
position = positions[sid]
|
||||
position = zp.Position(sid)
|
||||
position.amount = pos.amount
|
||||
position.cost_basis = pos.cost_basis
|
||||
position.last_sale_price = pos.last_sale_price
|
||||
position.last_sale_date = pos.last_sale_date
|
||||
|
||||
# Adds the new position if we didn't have one before, or overwrite
|
||||
# one we have currently
|
||||
positions[sid] = position
|
||||
|
||||
return positions
|
||||
|
||||
def get_positions_list(self):
|
||||
|
||||
@@ -164,5 +164,4 @@ class Positions(dict):
|
||||
|
||||
def __missing__(self, key):
|
||||
pos = Position(key)
|
||||
self[key] = pos
|
||||
return pos
|
||||
|
||||
@@ -1126,3 +1126,17 @@ def initialize(context):
|
||||
def handle_data(context, data):
|
||||
context.get_open_orders(symbol('TEST'))
|
||||
"""
|
||||
|
||||
empty_positions = """
|
||||
from zipline.api import record, schedule_function, time_rules, date_rules, \
|
||||
symbol
|
||||
|
||||
def initialize(context):
|
||||
schedule_function(test_history, date_rules.every_day(),
|
||||
time_rules.market_open(hours=1))
|
||||
context.sid = symbol('TEST')
|
||||
|
||||
def test_history(context,data):
|
||||
record(amounts=context.portfolio.positions[context.sid].amount)
|
||||
record(num_positions=len(context.portfolio.positions))
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user