generator-style perf now sends a risk report on receipt of DONE

This commit is contained in:
scottsanderson
2012-08-03 21:09:05 -04:00
parent 6a181ab4e7
commit 8437a28c14
9 changed files with 79 additions and 86 deletions
-1
View File
@@ -114,7 +114,6 @@ class Component(object):
# Core Methods
# ------------
def loop_send(self):
"""
The main component loop. This is wrapped inside a
+14 -11
View File
@@ -203,10 +203,15 @@ class PerformanceTracker(object):
self.todays_performance.positions[sid] = Position(sid)
def update(self, event):
event.perf_message = self.process_event(event)
event.portfolio = self.get_portfolio()
del event['TRANSACTION']
return event
if event.dt == "DONE":
event.perf_message = self.handle_simulation_end()
del event['TRANSACTION']
return event
else:
event.perf_message = self.process_event(event)
event.portfolio = self.get_portfolio()
del event['TRANSACTION']
return event
def get_portfolio(self):
return self.cumulative_performance.as_portfolio()
@@ -270,6 +275,7 @@ class PerformanceTracker(object):
#calculate performance as of last trade
self.cumulative_performance.calculate_performance()
self.todays_performance.calculate_performance()
return message
@@ -296,7 +302,8 @@ class PerformanceTracker(object):
# calculate progress of test
self.progress = self.day_count / self.total_days
#TODO TODO TODO!!
# Take a snapshot of our current peformance to return to the
# browser.
daily_update = self.to_dict()
if self.trading_environment.max_drawdown:
@@ -356,12 +363,8 @@ class PerformanceTracker(object):
exceeded_max_loss = self.exceeded_max_loss
)
if self.results_socket:
log.info("about to stream the risk report...")
risk_dict = self.risk_report.to_dict()
msg = zp.RISK_FRAME(risk_dict)
self.results_socket.send(msg)
risk_dict = self.risk_report.to_dict()
return risk_dict
class Position(object):
+1
View File
@@ -35,6 +35,7 @@ class TransactionSimulator(object):
def update(self, event):
event.TRANSACTION = None
# We only fill transactions on trade events.
if event.type == zp.DATASOURCE_TYPE.TRADE:
event.TRANSACTION = self.apply_trade_to_open_orders(event)
return event
+21 -18
View File
@@ -3,7 +3,7 @@ from itertools import tee, starmap
from collections import namedtuple
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.utils import roundrobin, hash_args
from zipline.gens.utils import roundrobin, hash_args, done_message
from zipline.gens.sort import date_sort
from zipline.gens.merge import merge
from zipline.gens.transform import StatefulTransform
@@ -34,8 +34,7 @@ def date_sorted_sources(*sources):
return date_sort(stream_in, names)
def merged_transforms(sorted_stream, bundles):
def merged_transforms(sorted_stream, *transforms):
"""
A generator that takes the expected output of a date_sort, pipes it
through a given set of transforms, and runs the results throught a
@@ -45,30 +44,34 @@ def merged_transforms(sorted_stream, bundles):
tnfm_kwargs should be a list of dictionaries representing keyword
arguments to each transform.
"""
for transform in transforms:
assert isinstance(transform, StatefulTransform)
# Generate expected hashes for each transform
namestrings = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
for bundle in bundles]
namestrings = [tnfm.get_hash() for tnfm in transforms]
# Create a copy of the stream for each transform.
split = tee(sorted_stream, len(bundles))
# Package a stream copy with each bundle
tnfms_with_streams = zip(split, bundles)
split = tee(sorted_stream, len(transforms))
# Package a stream copy with each StatefulTransform instance.
bundles = zip(transforms, split)
# Convert the copies into transform streams.
tnfm_gens = [
StatefulTransform(
stream_copy,
bundle.tnfm,
*bundle.args,
**bundle.kwargs
)
for stream_copy, bundle in tnfms_with_streams
]
tnfm_gens = [tnfm.transform(stream) for tnfm, stream in bundles]
# Roundrobin the outputs of our transforms to create a single flat stream.
# Roundrobin the outputs of our transforms to create a single flat
# stream.
to_merge = roundrobin(tnfm_gens, namestrings)
# Pipe the stream into merge.
merged = merge(to_merge, namestrings)
# Return the merged events.
return merged
def zipline(sources, transforms, endpoint):
assert isinstance(sources, (list, tuple))
assert isinstance(transforms, (list, tuple))
+12 -13
View File
@@ -21,10 +21,10 @@ if __name__ == "__main__":
#Set up source a. One minute between events.
args_a = tuple()
kwargs_a = {
'count' : 2000,
'count' : 325,
'sids' : [1,2,3],
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
'delta' : timedelta(minutes = 10),
'delta' : timedelta(hours = 6),
'filter' : filter
}
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
@@ -32,30 +32,29 @@ if __name__ == "__main__":
#Set up source b. Two minutes between events.
args_b = tuple()
kwargs_b = {
'count' : 2000,
'count' : 7500,
'sids' : [2,3,4],
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
'delta' : timedelta(minutes = 10),
'delta' : timedelta(minutes = 5),
'filter' : filter
}
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
#Set up source c. Three minutes between events.
sort_out = date_sorted_sources(source_a, source_b)
sorted = date_sorted_sources(source_a, source_b)
passthrough = TransformBundle(Passthrough, (), {})
mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {})
tnfm_bundles = (passthrough, mavg_price)
passthrough = StatefulTransform(Passthrough)
mavg_price = StatefulTransform(MovingAverage, timedelta(minutes = 20), ['price'])
merge_out = merged_transforms(sort_out, tnfm_bundles)
merged = merged_transforms(sorted, passthrough, mavg_price)
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
environment = create_trading_environment(year = 2012)
style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
client_out = tsc(merge_out, algo, environment, style)
for message in client_out:
pp(message)
sleep(1)
trading_client = tsc(algo, environment, style)
for message in trading_client.simulate(merged):
pp(message)
+2 -1
View File
@@ -6,7 +6,7 @@ from collections import deque
from zipline import ndict
from zipline.gens.utils import hash_args, \
assert_merge_protocol
assert_merge_protocol, done_message
from itertools import repeat
def merge(stream_in, tnfm_ids):
@@ -51,6 +51,7 @@ def merge(stream_in, tnfm_ids):
assert len(queue) == 1, "Bad queue in merge on exit: %s" % queue
assert queue[0].dt == "DONE", \
"Bad last message in merge on exit: %s" % queue
yield done_message('Merge')
def merge_one(sources):
dict_primer = zip(sources.keys(), repeat(None))
+21 -24
View File
@@ -43,65 +43,60 @@ class TradeSimulationClient(object):
is sent to the algo.
"""
def __init__(self, stream_in, algo, environment, sim_style):
def __init__(self, algo, environment, sim_style):
self.stream_in = stream_in
self.algo = algo
self.sids = algo.get_sid_filter()
self.environment = environment
self.style = sim_style
self.__generator = None
def get_hash(self):
"""
There should only ever be one TSC in the system.
"""
return self.__class__.__name__ + hash_args()
def __iter__(self):
return self
def next(self):
if self.__generator:
return self.__generator.next()
else:
self.__generator = self.run_simulation()
return self.__generator.next()
def run_simulation(self):
def simulate(self, stream_in):
"""
Main generator work loop.
"""
# Simulate filling any open orders made by the previous run of
# the user's algorithm. Sets the txn field to true on any
# event that results in a filled order.
ordering_client = StatefulTransform(
self.stream_in,
TransactionSimulator,
self.sids,
style = self.style
)
with_filled_orders = ordering_client.transform(stream_in)
# Pipe the events with transactions to perf. This will remove
# the txn field added by TransactionSimulator and replace it
# with a portfolio object to be passed to the user's
# algorithm. Also adds a PERF_MESSAGE field which is usually
# none, but contains an update message once per day.
current_portfolio = StatefulTransform(
ordering_client,
perf_tracker = StatefulTransform(
PerformanceTracker,
self.environment,
self.sids
)
# Pass both the ordering client's state and messages with the
# current portfolio into the algorithm for simulation.
with_portfolio = perf_tracker.transform(with_filled_orders)
# Pass the messages from perf along with the trading client's
# state into the algorithm for simulation. We provide the
# trading client so that the algorithm can place new orders
# into the client's order book.
algo_results = AlgorithmSimulator(
current_portfolio,
with_portfolio,
ordering_client.state,
self.algo,
)
# The algorithm will yield a daily_results message (as
# calculated by the performance tracker) at the end of each
# day. It will also yield a risk report at the end of the
# simulation.
for message in algo_results:
yield message
@@ -120,7 +115,7 @@ class AlgorithmSimulator(object):
self.sids = algo.get_sid_filter()
# Monkey patch the user algorithm to place orders in the
# txn_sim order book.
# TransactionSimulator's order book.
self.algo.set_order(self.order)
self.algo.set_logger(logbook.Logger("Algolog"))
@@ -189,6 +184,8 @@ class AlgorithmSimulator(object):
if event.perf_message:
yield event.perf_message
del event['perf_message']
if event.dt == "DONE":
break
# This should only happen for the first event we run.
if self.simulation_dt == None:
+7 -17
View File
@@ -48,7 +48,7 @@ class StatefulTransform(object):
set to true will forward all fields in the original message.
Otherwise only dt, tnfm_id, and tnfm_value are forwarded.
"""
def __init__(self, stream_in, tnfm_class, *args, **kwargs):
def __init__(self, tnfm_class, *args, **kwargs):
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
"Stateful transform requires a class."
assert tnfm_class.__dict__.has_key('update'), \
@@ -56,36 +56,26 @@ class StatefulTransform(object):
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
# You can't be both a forwarded and an updater.
assert not all([self.forward_all, self.update_in_place])
self.stream_in = stream_in
# Create an instance of our transform class.
self.state = tnfm_class(*args, **kwargs)
# Create the string associated with this generator's output.
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
# Generator isn't initialized until someone calls __iter__ or next().
self.__generator = None
def get_hash(self):
return self.namestring
def next(self):
if self.__generator:
return self.__generator.next()
else:
self.__generator = self._gen()
return self.__generator.next()
def transform(self, stream_in):
return self._gen(stream_in)
def __iter__(self):
return self
def _gen(self):
def _gen(self, stream_in):
# IMPORTANT: Messages may contain pointers that are shared with
# other streams, so we only manipulate copies.
for message in self.stream_in:
for message in stream_in:
assert_sort_unframe_protocol(message)
message_copy = deepcopy(message)
+1 -1
View File
@@ -23,7 +23,7 @@ def mock_done(id):
"source_id" : id,
'tnfm_id' : id,
'tnfm_value': None,
'type' : 0
'type' : DATASOURCE_TYPE.DONE
})
done_message = mock_done