mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 10:14:44 +08:00
generator-style perf now sends a risk report on receipt of DONE
This commit is contained in:
@@ -114,7 +114,6 @@ class Component(object):
|
||||
# Core Methods
|
||||
# ------------
|
||||
|
||||
|
||||
def loop_send(self):
|
||||
"""
|
||||
The main component loop. This is wrapped inside a
|
||||
|
||||
@@ -203,10 +203,15 @@ class PerformanceTracker(object):
|
||||
self.todays_performance.positions[sid] = Position(sid)
|
||||
|
||||
def update(self, event):
|
||||
event.perf_message = self.process_event(event)
|
||||
event.portfolio = self.get_portfolio()
|
||||
del event['TRANSACTION']
|
||||
return event
|
||||
if event.dt == "DONE":
|
||||
event.perf_message = self.handle_simulation_end()
|
||||
del event['TRANSACTION']
|
||||
return event
|
||||
else:
|
||||
event.perf_message = self.process_event(event)
|
||||
event.portfolio = self.get_portfolio()
|
||||
del event['TRANSACTION']
|
||||
return event
|
||||
|
||||
def get_portfolio(self):
|
||||
return self.cumulative_performance.as_portfolio()
|
||||
@@ -270,6 +275,7 @@ class PerformanceTracker(object):
|
||||
#calculate performance as of last trade
|
||||
self.cumulative_performance.calculate_performance()
|
||||
self.todays_performance.calculate_performance()
|
||||
|
||||
|
||||
return message
|
||||
|
||||
@@ -296,7 +302,8 @@ class PerformanceTracker(object):
|
||||
# calculate progress of test
|
||||
self.progress = self.day_count / self.total_days
|
||||
|
||||
#TODO TODO TODO!!
|
||||
# Take a snapshot of our current peformance to return to the
|
||||
# browser.
|
||||
daily_update = self.to_dict()
|
||||
|
||||
if self.trading_environment.max_drawdown:
|
||||
@@ -356,12 +363,8 @@ class PerformanceTracker(object):
|
||||
exceeded_max_loss = self.exceeded_max_loss
|
||||
)
|
||||
|
||||
if self.results_socket:
|
||||
log.info("about to stream the risk report...")
|
||||
risk_dict = self.risk_report.to_dict()
|
||||
|
||||
msg = zp.RISK_FRAME(risk_dict)
|
||||
self.results_socket.send(msg)
|
||||
risk_dict = self.risk_report.to_dict()
|
||||
return risk_dict
|
||||
|
||||
class Position(object):
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class TransactionSimulator(object):
|
||||
|
||||
def update(self, event):
|
||||
event.TRANSACTION = None
|
||||
# We only fill transactions on trade events.
|
||||
if event.type == zp.DATASOURCE_TYPE.TRADE:
|
||||
event.TRANSACTION = self.apply_trade_to_open_orders(event)
|
||||
return event
|
||||
|
||||
+21
-18
@@ -3,7 +3,7 @@ from itertools import tee, starmap
|
||||
from collections import namedtuple
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import roundrobin, hash_args
|
||||
from zipline.gens.utils import roundrobin, hash_args, done_message
|
||||
from zipline.gens.sort import date_sort
|
||||
from zipline.gens.merge import merge
|
||||
from zipline.gens.transform import StatefulTransform
|
||||
@@ -34,8 +34,7 @@ def date_sorted_sources(*sources):
|
||||
|
||||
return date_sort(stream_in, names)
|
||||
|
||||
|
||||
def merged_transforms(sorted_stream, bundles):
|
||||
def merged_transforms(sorted_stream, *transforms):
|
||||
"""
|
||||
A generator that takes the expected output of a date_sort, pipes it
|
||||
through a given set of transforms, and runs the results throught a
|
||||
@@ -45,30 +44,34 @@ def merged_transforms(sorted_stream, bundles):
|
||||
tnfm_kwargs should be a list of dictionaries representing keyword
|
||||
arguments to each transform.
|
||||
"""
|
||||
for transform in transforms:
|
||||
assert isinstance(transform, StatefulTransform)
|
||||
|
||||
# Generate expected hashes for each transform
|
||||
namestrings = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
namestrings = [tnfm.get_hash() for tnfm in transforms]
|
||||
|
||||
# Create a copy of the stream for each transform.
|
||||
split = tee(sorted_stream, len(bundles))
|
||||
# Package a stream copy with each bundle
|
||||
tnfms_with_streams = zip(split, bundles)
|
||||
split = tee(sorted_stream, len(transforms))
|
||||
|
||||
# Package a stream copy with each StatefulTransform instance.
|
||||
bundles = zip(transforms, split)
|
||||
|
||||
# Convert the copies into transform streams.
|
||||
tnfm_gens = [
|
||||
StatefulTransform(
|
||||
stream_copy,
|
||||
bundle.tnfm,
|
||||
*bundle.args,
|
||||
**bundle.kwargs
|
||||
)
|
||||
for stream_copy, bundle in tnfms_with_streams
|
||||
]
|
||||
tnfm_gens = [tnfm.transform(stream) for tnfm, stream in bundles]
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
# Roundrobin the outputs of our transforms to create a single flat
|
||||
# stream.
|
||||
to_merge = roundrobin(tnfm_gens, namestrings)
|
||||
|
||||
# Pipe the stream into merge.
|
||||
merged = merge(to_merge, namestrings)
|
||||
# Return the merged events.
|
||||
return merged
|
||||
|
||||
def zipline(sources, transforms, endpoint):
|
||||
assert isinstance(sources, (list, tuple))
|
||||
assert isinstance(transforms, (list, tuple))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+12
-13
@@ -21,10 +21,10 @@ if __name__ == "__main__":
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 2000,
|
||||
'count' : 325,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 10),
|
||||
'delta' : timedelta(hours = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
@@ -32,30 +32,29 @@ if __name__ == "__main__":
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 2000,
|
||||
'count' : 7500,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 10),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
|
||||
sort_out = date_sorted_sources(source_a, source_b)
|
||||
sorted = date_sorted_sources(source_a, source_b)
|
||||
|
||||
passthrough = TransformBundle(Passthrough, (), {})
|
||||
mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {})
|
||||
tnfm_bundles = (passthrough, mavg_price)
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(MovingAverage, timedelta(minutes = 20), ['price'])
|
||||
|
||||
merge_out = merged_transforms(sort_out, tnfm_bundles)
|
||||
merged = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
environment = create_trading_environment(year = 2012)
|
||||
style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
client_out = tsc(merge_out, algo, environment, style)
|
||||
for message in client_out:
|
||||
pp(message)
|
||||
sleep(1)
|
||||
trading_client = tsc(algo, environment, style)
|
||||
|
||||
for message in trading_client.simulate(merged):
|
||||
pp(message)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, \
|
||||
assert_merge_protocol
|
||||
assert_merge_protocol, done_message
|
||||
from itertools import repeat
|
||||
|
||||
def merge(stream_in, tnfm_ids):
|
||||
@@ -51,6 +51,7 @@ def merge(stream_in, tnfm_ids):
|
||||
assert len(queue) == 1, "Bad queue in merge on exit: %s" % queue
|
||||
assert queue[0].dt == "DONE", \
|
||||
"Bad last message in merge on exit: %s" % queue
|
||||
yield done_message('Merge')
|
||||
|
||||
def merge_one(sources):
|
||||
dict_primer = zip(sources.keys(), repeat(None))
|
||||
|
||||
@@ -43,65 +43,60 @@ class TradeSimulationClient(object):
|
||||
is sent to the algo.
|
||||
"""
|
||||
|
||||
def __init__(self, stream_in, algo, environment, sim_style):
|
||||
def __init__(self, algo, environment, sim_style):
|
||||
|
||||
self.stream_in = stream_in
|
||||
self.algo = algo
|
||||
self.sids = algo.get_sid_filter()
|
||||
self.environment = environment
|
||||
self.style = sim_style
|
||||
|
||||
self.__generator = None
|
||||
|
||||
|
||||
def get_hash(self):
|
||||
"""
|
||||
There should only ever be one TSC in the system.
|
||||
"""
|
||||
return self.__class__.__name__ + hash_args()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.__generator:
|
||||
return self.__generator.next()
|
||||
else:
|
||||
self.__generator = self.run_simulation()
|
||||
return self.__generator.next()
|
||||
|
||||
def run_simulation(self):
|
||||
def simulate(self, stream_in):
|
||||
"""
|
||||
Main generator work loop.
|
||||
"""
|
||||
|
||||
# Simulate filling any open orders made by the previous run of
|
||||
# the user's algorithm. Sets the txn field to true on any
|
||||
# event that results in a filled order.
|
||||
ordering_client = StatefulTransform(
|
||||
self.stream_in,
|
||||
TransactionSimulator,
|
||||
self.sids,
|
||||
style = self.style
|
||||
)
|
||||
with_filled_orders = ordering_client.transform(stream_in)
|
||||
|
||||
# Pipe the events with transactions to perf. This will remove
|
||||
# the txn field added by TransactionSimulator and replace it
|
||||
# with a portfolio object to be passed to the user's
|
||||
# algorithm. Also adds a PERF_MESSAGE field which is usually
|
||||
# none, but contains an update message once per day.
|
||||
current_portfolio = StatefulTransform(
|
||||
ordering_client,
|
||||
perf_tracker = StatefulTransform(
|
||||
PerformanceTracker,
|
||||
self.environment,
|
||||
self.sids
|
||||
)
|
||||
# Pass both the ordering client's state and messages with the
|
||||
# current portfolio into the algorithm for simulation.
|
||||
with_portfolio = perf_tracker.transform(with_filled_orders)
|
||||
|
||||
# Pass the messages from perf along with the trading client's
|
||||
# state into the algorithm for simulation. We provide the
|
||||
# trading client so that the algorithm can place new orders
|
||||
# into the client's order book.
|
||||
algo_results = AlgorithmSimulator(
|
||||
current_portfolio,
|
||||
with_portfolio,
|
||||
ordering_client.state,
|
||||
self.algo,
|
||||
)
|
||||
|
||||
|
||||
# The algorithm will yield a daily_results message (as
|
||||
# calculated by the performance tracker) at the end of each
|
||||
# day. It will also yield a risk report at the end of the
|
||||
# simulation.
|
||||
for message in algo_results:
|
||||
yield message
|
||||
|
||||
@@ -120,7 +115,7 @@ class AlgorithmSimulator(object):
|
||||
self.sids = algo.get_sid_filter()
|
||||
|
||||
# Monkey patch the user algorithm to place orders in the
|
||||
# txn_sim order book.
|
||||
# TransactionSimulator's order book.
|
||||
self.algo.set_order(self.order)
|
||||
self.algo.set_logger(logbook.Logger("Algolog"))
|
||||
|
||||
@@ -189,6 +184,8 @@ class AlgorithmSimulator(object):
|
||||
if event.perf_message:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
if event.dt == "DONE":
|
||||
break
|
||||
|
||||
# This should only happen for the first event we run.
|
||||
if self.simulation_dt == None:
|
||||
|
||||
@@ -48,7 +48,7 @@ class StatefulTransform(object):
|
||||
set to true will forward all fields in the original message.
|
||||
Otherwise only dt, tnfm_id, and tnfm_value are forwarded.
|
||||
"""
|
||||
def __init__(self, stream_in, tnfm_class, *args, **kwargs):
|
||||
def __init__(self, tnfm_class, *args, **kwargs):
|
||||
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
|
||||
"Stateful transform requires a class."
|
||||
assert tnfm_class.__dict__.has_key('update'), \
|
||||
@@ -56,36 +56,26 @@ class StatefulTransform(object):
|
||||
|
||||
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
|
||||
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
|
||||
|
||||
# You can't be both a forwarded and an updater.
|
||||
assert not all([self.forward_all, self.update_in_place])
|
||||
|
||||
self.stream_in = stream_in
|
||||
|
||||
# Create an instance of our transform class.
|
||||
self.state = tnfm_class(*args, **kwargs)
|
||||
|
||||
# Create the string associated with this generator's output.
|
||||
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
# Generator isn't initialized until someone calls __iter__ or next().
|
||||
self.__generator = None
|
||||
|
||||
def get_hash(self):
|
||||
return self.namestring
|
||||
|
||||
def next(self):
|
||||
if self.__generator:
|
||||
return self.__generator.next()
|
||||
else:
|
||||
self.__generator = self._gen()
|
||||
return self.__generator.next()
|
||||
def transform(self, stream_in):
|
||||
return self._gen(stream_in)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def _gen(self):
|
||||
def _gen(self, stream_in):
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
for message in self.stream_in:
|
||||
for message in stream_in:
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
message_copy = deepcopy(message)
|
||||
|
||||
@@ -23,7 +23,7 @@ def mock_done(id):
|
||||
"source_id" : id,
|
||||
'tnfm_id' : id,
|
||||
'tnfm_value': None,
|
||||
'type' : 0
|
||||
'type' : DATASOURCE_TYPE.DONE
|
||||
})
|
||||
|
||||
done_message = mock_done
|
||||
|
||||
Reference in New Issue
Block a user