mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 17:21:10 +08:00
API: Add slippage and commission models for futures
This commit is contained in:
@@ -56,6 +56,7 @@ from zipline.data.us_equity_pricing import (
|
||||
from zipline.errors import (
|
||||
AccountControlViolation,
|
||||
CannotOrderDelistedAsset,
|
||||
IncompatibleSlippageModel,
|
||||
OrderDuringInitialize,
|
||||
OrderInBeforeTradingStart,
|
||||
RegisterTradingControlPostInit,
|
||||
@@ -1738,6 +1739,27 @@ def handle_data(context, data):
|
||||
finally:
|
||||
tempdir.cleanup()
|
||||
|
||||
def test_incorrectly_set_futures_slippage_model(self):
|
||||
code = dedent(
|
||||
"""
|
||||
from zipline.api import set_slippage, slippage
|
||||
|
||||
class MySlippage(slippage.FutureSlippageModel):
|
||||
def process_order(self, data, order):
|
||||
return data.current(order.asset, 'price'), order.amount
|
||||
|
||||
def initialize(context):
|
||||
set_slippage(MySlippage())
|
||||
"""
|
||||
)
|
||||
test_algo = TradingAlgorithm(
|
||||
script=code, sim_params=self.sim_params, env=self.env,
|
||||
)
|
||||
with self.assertRaises(IncompatibleSlippageModel):
|
||||
# Passing a futures slippage model as the first argument, which is
|
||||
# for setting equity models, should fail.
|
||||
test_algo.run(self.data_portal)
|
||||
|
||||
def test_algo_record_vars(self):
|
||||
test_algo = TradingAlgorithm(
|
||||
script=record_variables,
|
||||
@@ -3655,6 +3677,99 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
algo.history_values[1].values, list(map(float, range(3636, 3641))),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def algo_with_slippage(slippage_model):
|
||||
return dedent(
|
||||
"""
|
||||
from zipline.api import (
|
||||
commission,
|
||||
order,
|
||||
set_commission,
|
||||
set_slippage,
|
||||
sid,
|
||||
slippage,
|
||||
get_datetime,
|
||||
)
|
||||
|
||||
def initialize(context):
|
||||
commission_model = commission.PerFutureTrade(0)
|
||||
set_commission(us_futures=commission_model)
|
||||
slippage_model = slippage.{model}
|
||||
set_slippage(us_futures=slippage_model)
|
||||
context.ordered = False
|
||||
|
||||
def handle_data(context, data):
|
||||
if not context.ordered:
|
||||
order(sid(1), 10)
|
||||
context.ordered = True
|
||||
context.order_price = data.current(sid(1), 'price')
|
||||
"""
|
||||
).format(model=slippage_model)
|
||||
|
||||
def test_fixed_future_slippage(self):
|
||||
algo_code = self.algo_with_slippage('FixedSlippage(spread=0.10)')
|
||||
algo = TradingAlgorithm(
|
||||
script=algo_code,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
trading_calendar=get_calendar('us_futures'),
|
||||
)
|
||||
results = algo.run(self.data_portal)
|
||||
|
||||
# Flatten the list of transactions.
|
||||
all_txns = [
|
||||
val for sublist in results['transactions'].tolist()
|
||||
for val in sublist
|
||||
]
|
||||
|
||||
self.assertEqual(len(all_txns), 1)
|
||||
txn = all_txns[0]
|
||||
|
||||
# Add 1 to the expected price because the order does not fill until the
|
||||
# bar after the price is recorded.
|
||||
expected_spread = 0.05
|
||||
expected_price = (algo.order_price + 1) + expected_spread
|
||||
|
||||
# Capital used should be 0 because there is no commission, and the cost
|
||||
# to enter into a long position on a futures contract is 0.
|
||||
self.assertEqual(txn['price'], expected_price)
|
||||
self.assertEqual(results['orders'][0][0]['commission'], 0.0)
|
||||
self.assertEqual(results.capital_used[0], 0.0)
|
||||
|
||||
def test_volume_contract_slippage(self):
|
||||
algo_code = self.algo_with_slippage(
|
||||
'VolumeShareSlippage(volume_limit=0.05, price_impact=0.1)',
|
||||
)
|
||||
algo = TradingAlgorithm(
|
||||
script=algo_code,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
trading_calendar=get_calendar('us_futures'),
|
||||
)
|
||||
results = algo.run(self.data_portal)
|
||||
|
||||
# There should be no commissions.
|
||||
self.assertEqual(results['orders'][0][0]['commission'], 0.0)
|
||||
|
||||
# Flatten the list of transactions.
|
||||
all_txns = [
|
||||
val for sublist in results['transactions'].tolist()
|
||||
for val in sublist
|
||||
]
|
||||
|
||||
# With a volume limit of 0.05, and a total volume of 100 contracts
|
||||
# traded per minute, we should require 2 transactions to order 10
|
||||
# contracts.
|
||||
self.assertEqual(len(all_txns), 2)
|
||||
|
||||
for i, txn in enumerate(all_txns):
|
||||
# Add 1 to the order price because the order does not fill until
|
||||
# the bar after the price is recorded.
|
||||
order_price = algo.order_price + i + 1
|
||||
expected_impact = order_price * 0.1 * (0.05 ** 2)
|
||||
expected_price = order_price + expected_impact
|
||||
self.assertEqual(txn['price'], expected_price)
|
||||
|
||||
|
||||
class TestTradingAlgorithm(ZiplineTestCase):
|
||||
def test_analyze_called(self):
|
||||
|
||||
Reference in New Issue
Block a user