mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 09:34:49 +08:00
Removes done message.
Instead of checking for 'DONE' on each call uses generators builtin StopIteration for signalling the end of input.
This commit is contained in:
@@ -27,7 +27,6 @@ from operator import attrgetter
|
||||
import zipline.utils.factory as factory
|
||||
import zipline.finance.performance as perf
|
||||
from zipline.utils.protocol_utils import ndict
|
||||
from zipline.protocol import Event
|
||||
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
|
||||
@@ -666,13 +665,6 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
# Extract events with transactions to use for verification.
|
||||
events_with_txns = [event for event in events if event.TRANSACTION]
|
||||
|
||||
done_message = Event({
|
||||
'dt': 'DONE',
|
||||
'TRANSACTION': None
|
||||
})
|
||||
|
||||
events = itertools.chain(events, [done_message])
|
||||
|
||||
perf_messages = \
|
||||
[msg for date, snapshot in
|
||||
perf_tracker.transform(
|
||||
@@ -680,6 +672,10 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
for event in snapshot
|
||||
for msg in event.perf_messages]
|
||||
|
||||
end_perf_messages, risk_message = perf_tracker.handle_simulation_end()
|
||||
|
||||
perf_messages.extend(end_perf_messages)
|
||||
|
||||
#we skip two trades, to test case of None transaction
|
||||
self.assertEqual(perf_tracker.txn_count, len(events_with_txns))
|
||||
|
||||
@@ -696,7 +692,7 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
def event_with_txn(self, event, no_txn_dt):
|
||||
#create a transaction for all but
|
||||
#first trade in each sid, to simulate None transaction
|
||||
if event.dt != no_txn_dt and event.dt != 'DONE':
|
||||
if event.dt != no_txn_dt:
|
||||
txn = ndict({
|
||||
'sid': event.sid,
|
||||
'amount': -25,
|
||||
|
||||
@@ -219,12 +219,8 @@ class PerformanceTracker(object):
|
||||
new_snapshot = []
|
||||
|
||||
for event in snapshot:
|
||||
if date != "DONE":
|
||||
event.perf_messages = self.process_event(event)
|
||||
event.portfolio = self.get_portfolio()
|
||||
else:
|
||||
event.perf_messages, event.risk_message = \
|
||||
self.handle_simulation_end()
|
||||
event.perf_messages = self.process_event(event)
|
||||
event.portfolio = self.get_portfolio()
|
||||
|
||||
del event['TRANSACTION']
|
||||
new_snapshot.append(event)
|
||||
|
||||
@@ -15,10 +15,6 @@
|
||||
|
||||
import heapq
|
||||
|
||||
from itertools import chain
|
||||
|
||||
from zipline.gens.utils import done_message
|
||||
|
||||
|
||||
def decorate_source(source):
|
||||
for message in source:
|
||||
@@ -55,8 +51,7 @@ def sequential_transforms(stream_in, *transforms):
|
||||
transforms,
|
||||
stream_in)
|
||||
|
||||
dt_aliased = alias_dt(stream_out)
|
||||
return add_done(dt_aliased)
|
||||
return alias_dt(stream_out)
|
||||
|
||||
|
||||
def alias_dt(stream_in):
|
||||
@@ -66,8 +61,3 @@ def alias_dt(stream_in):
|
||||
for message in stream_in:
|
||||
message['datetime'] = message['dt']
|
||||
yield message
|
||||
|
||||
|
||||
# Add a done message to a stream.
|
||||
def add_done(stream_in):
|
||||
return chain(stream_in, [done_message('Composite')])
|
||||
|
||||
@@ -71,6 +71,7 @@ class TradeSimulationClient(object):
|
||||
self.algo_start = self.environment.first_open
|
||||
self.algo_sim = AlgorithmSimulator(
|
||||
self.ordering_client,
|
||||
self.perf_tracker,
|
||||
self.algo,
|
||||
self.algo_start
|
||||
)
|
||||
@@ -116,6 +117,7 @@ class AlgorithmSimulator(object):
|
||||
|
||||
def __init__(self,
|
||||
order_book,
|
||||
perf_tracker,
|
||||
algo,
|
||||
algo_start):
|
||||
|
||||
@@ -126,6 +128,7 @@ class AlgorithmSimulator(object):
|
||||
# We extract the order book from the txn client so that
|
||||
# the algo can place new orders.
|
||||
self.order_book = order_book
|
||||
self.perf_tracker = perf_tracker
|
||||
|
||||
self.algo = algo
|
||||
self.algo_start = algo_start.replace(hour=0, minute=0,
|
||||
@@ -203,18 +206,10 @@ class AlgorithmSimulator(object):
|
||||
if self.simulation_dt is None:
|
||||
self.simulation_dt = date
|
||||
|
||||
# Done message has the risk report, so we yield before exiting.
|
||||
if date == 'DONE':
|
||||
for event in snapshot:
|
||||
for perf_message in event.perf_messages:
|
||||
yield perf_message
|
||||
yield event.risk_message
|
||||
raise StopIteration
|
||||
|
||||
# We're still in the warmup period. Use the event to
|
||||
# update our universe, but don't yield any perf messages,
|
||||
# and don't send a snapshot to handle_data.
|
||||
elif date < self.algo_start:
|
||||
if date < self.algo_start:
|
||||
for event in snapshot:
|
||||
del event['perf_messages']
|
||||
self.update_universe(event)
|
||||
@@ -233,6 +228,14 @@ class AlgorithmSimulator(object):
|
||||
# to the user's algo.
|
||||
self.simulate_snapshot(date)
|
||||
|
||||
perf_messages, risk_message = \
|
||||
self.perf_tracker.handle_simulation_end()
|
||||
|
||||
for message in perf_messages:
|
||||
yield message
|
||||
|
||||
yield risk_message
|
||||
|
||||
def update_universe(self, event):
|
||||
"""
|
||||
Update the universe with new event information.
|
||||
|
||||
@@ -20,7 +20,6 @@ import numbers
|
||||
from hashlib import md5
|
||||
from datetime import datetime
|
||||
from itertools import izip_longest
|
||||
from zipline import ndict
|
||||
from zipline.protocol import (
|
||||
DATASOURCE_TYPE,
|
||||
Event
|
||||
@@ -37,18 +36,6 @@ def mock_raw_event(sid, dt):
|
||||
return event
|
||||
|
||||
|
||||
def mock_done(id):
|
||||
return ndict({
|
||||
'dt': "DONE",
|
||||
"source_id": id,
|
||||
'tnfm_id': id,
|
||||
'tnfm_value': None,
|
||||
'type': DATASOURCE_TYPE.DONE
|
||||
})
|
||||
|
||||
done_message = mock_done
|
||||
|
||||
|
||||
def alternate(g1, g2):
|
||||
"""Specialized version of roundrobin for just 2 generators."""
|
||||
for e1, e2 in izip_longest(g1, g2):
|
||||
|
||||
Reference in New Issue
Block a user