mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 22:35:46 +08:00
TST:adds algo unittest for ClOSE_POSITON event type, adds commission as parameter to TestAlgorithm
This commit is contained in:
@@ -44,6 +44,7 @@ from zipline.test_algorithms import (
|
||||
EmptyPositionsAlgorithm,
|
||||
InvalidOrderAlgorithm,
|
||||
RecordAlgorithm,
|
||||
TestAlgorithm,
|
||||
TestOrderAlgorithm,
|
||||
TestOrderInstantAlgorithm,
|
||||
TestOrderPercentAlgorithm,
|
||||
@@ -90,6 +91,10 @@ from zipline.finance.trading import SimulationParameters
|
||||
from zipline.utils.api_support import set_algo_instance
|
||||
from zipline.utils.events import DateRuleFactory, TimeRuleFactory
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.finance import trading
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.finance.commission import PerShare
|
||||
|
||||
|
||||
class TestRecordAlgorithm(TestCase):
|
||||
@@ -1178,3 +1183,46 @@ class TestAccountControls(TestCase):
|
||||
|
||||
algo = SetMaxLeverageAlgorithm(1)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
|
||||
class TestClosePosAlgo(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
days = TradingEnvironment().trading_days
|
||||
self.index = [days[0], days[1], days[2]]
|
||||
pan = pd.Panel({1: pd.DataFrame({
|
||||
'price': [1, 2, 4], 'volume': [1e9, 0, 0],
|
||||
'type': [DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.CLOSE_POSITION]},
|
||||
index=self.index)
|
||||
})
|
||||
|
||||
self.data = DataPanelSource(pan)
|
||||
self.algo = TestAlgorithm(sid=1, amount=1, order_count=1,
|
||||
instant_fill=True, commission=PerShare(0))
|
||||
self.results = self.run_algo()
|
||||
self.expected_positions = [1, 1, 0]
|
||||
self.expected_pnl = [0, 1, 2]
|
||||
|
||||
def run_algo(self):
|
||||
results = self.algo.run(self.data)
|
||||
return results
|
||||
|
||||
def test_algo_pnl(self):
|
||||
for i, pnl in enumerate(self.results.pnl):
|
||||
self.assertEqual(pnl, self.expected_pnl[i])
|
||||
|
||||
def test_algo_positions(self):
|
||||
for i, amount in enumerate(self.results.positions):
|
||||
if amount:
|
||||
actual_position = amount[0]['amount']
|
||||
else:
|
||||
actual_position = 0
|
||||
|
||||
self.assertEqual(actual_position, self.expected_positions[i])
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
self.algo = None
|
||||
trading.environment = None
|
||||
|
||||
@@ -108,7 +108,8 @@ class TestAlgorithm(TradingAlgorithm):
|
||||
amount,
|
||||
order_count,
|
||||
sid_filter=None,
|
||||
slippage=None):
|
||||
slippage=None,
|
||||
commission=None):
|
||||
self.count = order_count
|
||||
self.sid = sid
|
||||
self.amount = amount
|
||||
@@ -122,6 +123,9 @@ class TestAlgorithm(TradingAlgorithm):
|
||||
if slippage is not None:
|
||||
self.set_slippage(slippage)
|
||||
|
||||
if commission is not None:
|
||||
self.set_commission(commission)
|
||||
|
||||
def handle_data(self, data):
|
||||
# place an order for amount shares of sid
|
||||
if self.incr < self.count:
|
||||
|
||||
Reference in New Issue
Block a user