mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 21:44:25 +08:00
@@ -1891,3 +1891,34 @@ class TestFutureFlip(TestCase):
|
||||
actual_position, expected_positions[i],
|
||||
"position for day={0} not equal, actual={1}, expected={2}".
|
||||
format(i, actual_position, expected_positions[i]))
|
||||
|
||||
|
||||
class TestTradingAlgorithm(TestCase):
|
||||
def setUp(self):
|
||||
self.env = TradingEnvironment()
|
||||
self.days = self.env.trading_days[:4]
|
||||
self.panel = pd.Panel({1: pd.DataFrame({
|
||||
'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0],
|
||||
'type': [DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.CLOSE_POSITION]},
|
||||
index=self.days)
|
||||
})
|
||||
|
||||
def test_analyze_called(self):
|
||||
self.perf_ref = None
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
pass
|
||||
|
||||
def analyze(context, perf):
|
||||
self.perf_ref = perf
|
||||
|
||||
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
|
||||
analyze=analyze)
|
||||
results = algo.run(self.panel)
|
||||
self.assertIs(results, self.perf_ref)
|
||||
|
||||
@@ -297,6 +297,7 @@ class TradingAlgorithm(object):
|
||||
self._handle_data = kwargs.pop('handle_data')
|
||||
self._before_trading_start = kwargs.pop('before_trading_start',
|
||||
None)
|
||||
self._analyze = kwargs.pop('analyze', None)
|
||||
|
||||
self.event_manager.add_event(
|
||||
zipline.utils.events.Event(
|
||||
|
||||
Reference in New Issue
Block a user