mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 12:26:23 +08:00
tradesim as generator v. 2.0
This commit is contained in:
@@ -11,8 +11,8 @@ log = logbook.Logger('Transaction Simulator')
|
||||
class TransactionSimulator(object):
|
||||
UPDATER = True
|
||||
|
||||
def __init__(self, open_orders, style=SIMULATION_STYLE.PARTIAL_VOLUME):
|
||||
self.open_orders = open_orders
|
||||
def __init__(self, sid_filter, style=SIMULATION_STYLE.PARTIAL_VOLUME):
|
||||
self.open_orders = {}
|
||||
self.txn_count = 0
|
||||
self.trade_window = datetime.timedelta(seconds=30)
|
||||
self.orderTTL = datetime.timedelta(days=1)
|
||||
@@ -27,6 +27,12 @@ class TransactionSimulator(object):
|
||||
elif style == SIMULATION_STYLE.NOOP:
|
||||
self.apply_trade_to_open_orders = self.simulate_noop
|
||||
|
||||
for sid in sid_filter:
|
||||
self.open_orders[sid] = []
|
||||
|
||||
def place_order(self, order):
|
||||
self.open_orders[order.sid].append(order)
|
||||
|
||||
def update(self, event):
|
||||
event.TRANSACTION = None
|
||||
if event.type == zp.DATASOURCE_TYPE.TRADE:
|
||||
|
||||
@@ -15,7 +15,7 @@ def date_sorted_sources(*sources):
|
||||
"""
|
||||
Takes an iterable of SortBundles, generating namestrings and initialized datasources
|
||||
for each before piping them into a date_sort.
|
||||
n """
|
||||
"""
|
||||
|
||||
for source in sources:
|
||||
assert iter(source), "Source %s not iterable" % source
|
||||
@@ -55,7 +55,7 @@ def merged_transforms(sorted_stream, bundles):
|
||||
tnfms_with_streams = zip(split, bundles)
|
||||
|
||||
# Convert the copies into transform streams.
|
||||
tnfms = [
|
||||
tnfm_gens = [
|
||||
StatefulTransform(
|
||||
stream_copy,
|
||||
bundle.tnfm,
|
||||
@@ -64,8 +64,6 @@ def merged_transforms(sorted_stream, bundles):
|
||||
)
|
||||
for stream_copy, bundle in tnfms_with_streams
|
||||
]
|
||||
tnfm_gens = [tnfm.gen() for tnfm in tnfms]
|
||||
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
to_merge = roundrobin(tnfm_gens, namestrings)
|
||||
|
||||
@@ -11,7 +11,7 @@ from zipline.gens.composites import SourceBundle, TransformBundle, \
|
||||
date_sorted_sources, merged_transforms
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
|
||||
from zipline.gens.tradesimulation import trade_simulation_client as tsc
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
@@ -21,6 +21,7 @@ if __name__ == "__main__":
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 2000,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 10),
|
||||
@@ -31,6 +32,7 @@ if __name__ == "__main__":
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 2000,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 10),
|
||||
|
||||
@@ -84,7 +84,7 @@ class SpecificEquityTrades(object):
|
||||
self.generator = self.create_fresh_generator()
|
||||
|
||||
def __iter__(self):
|
||||
return self.generator
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
return self.generator.next()
|
||||
|
||||
+200
-124
@@ -9,7 +9,7 @@ from zipline.gens.transform import StatefulTransform
|
||||
from zipline.finance.trading import TransactionSimulator
|
||||
from zipline.finance.performance import PerformanceTracker
|
||||
|
||||
def trade_simulation_client(stream_in, algo, environment, sim_style):
|
||||
class TradeSimulationClient(object):
|
||||
"""
|
||||
Generator that takes the expected output of a merge, a user
|
||||
algorithm, a trading environment, and a simulator style as
|
||||
@@ -42,61 +42,123 @@ def trade_simulation_client(stream_in, algo, environment, sim_style):
|
||||
overwritten so that only the most recent snapshot of the universe
|
||||
is sent to the algo.
|
||||
"""
|
||||
|
||||
#============
|
||||
# Algo Setup
|
||||
#============
|
||||
|
||||
# Initialize txn_sim's dictionary of orders here so that we can
|
||||
# reference it from within the user's algorithm.
|
||||
|
||||
sids = algo.get_sid_filter()
|
||||
open_orders = {}
|
||||
def __init__(self, stream_in, algo, environment, sim_style):
|
||||
|
||||
for sid in sids:
|
||||
open_orders[sid] = []
|
||||
self.stream_in = stream_in
|
||||
self.algo = algo
|
||||
self.sids = algo.get_sid_filter()
|
||||
self.environment = environment
|
||||
self.style = sim_style
|
||||
|
||||
self.__generator = None
|
||||
|
||||
|
||||
def get_hash(self):
|
||||
"""
|
||||
There should only ever be one TSC in the system.
|
||||
"""
|
||||
return self.__class__.__name__ + hash_args()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.__generator:
|
||||
return self.__generator.next()
|
||||
else:
|
||||
self.__generator = self.run_simulation()
|
||||
return self.__generator.next()
|
||||
|
||||
def run_simulation(self):
|
||||
"""
|
||||
Main generator work loop.
|
||||
"""
|
||||
# Simulate filling any open orders made by the previous run of
|
||||
# the user's algorithm. Sets the txn field to true on any
|
||||
# event that results in a filled order.
|
||||
ordering_client = StatefulTransform(
|
||||
self.stream_in,
|
||||
TransactionSimulator,
|
||||
self.sids,
|
||||
style = self.style
|
||||
)
|
||||
# Pipe the events with transactions to perf. This will remove
|
||||
# the txn field added by TransactionSimulator and replace it
|
||||
# with a portfolio object to be passed to the user's
|
||||
# algorithm. Also adds a PERF_MESSAGE field which is usually
|
||||
# none, but contains an update message once per day.
|
||||
current_portfolio = StatefulTransform(
|
||||
ordering_client,
|
||||
PerformanceTracker,
|
||||
self.environment,
|
||||
self.sids
|
||||
)
|
||||
# Pass both the ordering client's state and messages with the
|
||||
# current portfolio into the algorithm for simulation.
|
||||
algo_results = AlgorithmSimulator(
|
||||
current_portfolio,
|
||||
ordering_client.state,
|
||||
self.algo,
|
||||
)
|
||||
|
||||
for message in algo_results:
|
||||
yield message
|
||||
|
||||
|
||||
class AlgorithmSimulator(object):
|
||||
|
||||
# Pipe the in stream into the transaction simulator.
|
||||
# Creates a txn field on the event containing transaction
|
||||
# information if we filled any pending orders on the event's sid.
|
||||
# TRANSACTION is None if we didn't fill any orders.
|
||||
with_txns = StatefulTransform(
|
||||
stream_in,
|
||||
TransactionSimulator,
|
||||
open_orders,
|
||||
style = sim_style
|
||||
)
|
||||
|
||||
# Pipe the events with transactions to perf. This will remove the
|
||||
# txn field added by TransactionSimulator and replace it with
|
||||
# a portfolio object to be passed to the user's algorithm. Also adds
|
||||
# a PERF_MESSAGE field which is usually none, but contains an update
|
||||
# message once per day.
|
||||
with_portfolio_and_perf_msg = StatefulTransform(
|
||||
with_txns,
|
||||
PerformanceTracker,
|
||||
environment,
|
||||
sids
|
||||
)
|
||||
|
||||
# Batch the event stream by dt to be processed by the user's algo.
|
||||
# Yields perf messages whenever it encounters them.
|
||||
perf_messages = algo_simulator(with_portfolio_and_perf_msg, sids, algo, open_orders)
|
||||
|
||||
for message in perf_messages:
|
||||
yield message
|
||||
|
||||
|
||||
def algo_simulator(stream_in, sids, algo, order_book):
|
||||
def __init__(self, stream_in, order_book, algo):
|
||||
|
||||
simulation_dt = None
|
||||
self.stream_in = stream_in
|
||||
|
||||
# Closure to pass into the user's algo to allow placing orders
|
||||
# into the txn_sim's dict of open orders.
|
||||
def order(sid, amount):
|
||||
assert sid in sids, "Order on invalid sid: %i" % sid
|
||||
# We extract the order book from the txn client so that
|
||||
# the algo can place new orders.
|
||||
self.order_book = order_book
|
||||
|
||||
self.algo = algo
|
||||
self.sids = algo.get_sid_filter()
|
||||
|
||||
# Monkey patch the user algorithm to place orders in the
|
||||
# txn_sim order book.
|
||||
self.algo.set_order(self.order)
|
||||
self.algo.set_logger(logbook.Logger("Algolog"))
|
||||
|
||||
# Call the user's initialize method.
|
||||
self.algo.initialize()
|
||||
|
||||
# The algorithm's universe as of our most recent event.
|
||||
self.universe = ndict()
|
||||
|
||||
for sid in self.sids:
|
||||
self.universe[sid] = ndict()
|
||||
self.universe.portfolio = None
|
||||
|
||||
# We don't have a datetime for the current snapshot until we
|
||||
# receive a message.
|
||||
self.simulation_dt = None
|
||||
self.this_snapshot_dt = None
|
||||
|
||||
self.__generator = None
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.__generator:
|
||||
return self.__generator.next()
|
||||
else:
|
||||
self.__generator = self._gen()
|
||||
return self.__generator.next()
|
||||
|
||||
def order(self, sid, amount):
|
||||
"""
|
||||
Closure to pass into the user's algo to allow placing orders
|
||||
into the txn_sim's dict of open orders.
|
||||
"""
|
||||
assert sid in self.sids, "Order on invalid sid: %i" % sid
|
||||
order = ndict({
|
||||
'dt' : simulation_dt,
|
||||
'dt' : self.simulation_dt,
|
||||
'sid' : sid,
|
||||
'amount' : int(amount),
|
||||
'filled' : 0
|
||||
@@ -104,91 +166,105 @@ def algo_simulator(stream_in, sids, algo, order_book):
|
||||
|
||||
# Tell the user if they try to buy 0 shares of something.
|
||||
if order.amount == 0:
|
||||
log = "requested to trade zero shares of {sid}".format(
|
||||
zero_message = "Requested to trade zero shares of {sid}".format(
|
||||
sid=event.sid
|
||||
)
|
||||
log.debug(log)
|
||||
log.debug(zero_message)
|
||||
# Don't bother placing orders for 0 shares.
|
||||
return
|
||||
|
||||
order_book[sid].append(order)
|
||||
|
||||
# Set the algo's order method.
|
||||
algo.set_order(order)
|
||||
|
||||
# Provide a logbook logging interface to user code.
|
||||
algo.set_logger(logbook.Logger("Algolog"))
|
||||
# Add non-zero orders to the order book.
|
||||
# !!!IMPORTANT SIDE-EFFECT!!!
|
||||
# This modifies the internal state of the transaction
|
||||
# simulator so that it can fill the placed order when it
|
||||
# receives its next message.
|
||||
self.order_book.place_order(order)
|
||||
|
||||
# Call user-defined initialize method before we process any
|
||||
# events.
|
||||
algo.initialize()
|
||||
|
||||
universe = ndict()
|
||||
for sid in sids:
|
||||
universe[sid] = ndict()
|
||||
universe.portfolio = None
|
||||
this_snapshot_dt = None
|
||||
|
||||
for event in stream_in:
|
||||
# Yield any perf messages received to be relayed back to the browser.
|
||||
if event.perf_message:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
|
||||
# This should only happen for the first event we run.
|
||||
if simulation_dt == None:
|
||||
simulation_dt = event.dt
|
||||
|
||||
# If we are currently creating a new message and this update
|
||||
# matches the message dt, update the state of the universe.
|
||||
|
||||
if this_snapshot_dt != None:
|
||||
|
||||
if event.dt == this_snapshot_dt:
|
||||
update_universe(event, universe)
|
||||
|
||||
# If we are constructing a snapshot and we hit a new dt, call
|
||||
# handle_data and record how long it takes.
|
||||
else:
|
||||
start_tic = datetime.now()
|
||||
algo.handle_data(universe)
|
||||
stop_tic = datetime.now()
|
||||
|
||||
# How long did you take?
|
||||
delta = stop_tic - start_tic
|
||||
|
||||
# Update the simulation time.
|
||||
simulation_dt = this_snapshot_dt + delta
|
||||
def _gen(self):
|
||||
"""
|
||||
Internal generator work loop.
|
||||
"""
|
||||
for event in self.stream_in:
|
||||
# Yield any perf messages received to be relayed back to the browser.
|
||||
if event.perf_message:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
|
||||
# Update the universe with the new event.
|
||||
update_universe(event, universe)
|
||||
# This should only happen for the first event we run.
|
||||
if self.simulation_dt == None:
|
||||
self.simulation_dt = event.dt
|
||||
|
||||
# ======================
|
||||
# Time Compression Logic
|
||||
# ======================
|
||||
|
||||
if self.this_snapshot_dt != None:
|
||||
self.update_current_snapshot(event)
|
||||
|
||||
# If the current event is later than the simulation
|
||||
# time, update the universe and start constructing
|
||||
# another snapshot.
|
||||
if event.dt >= simulation_dt:
|
||||
this_snapshot_dt = event.dt
|
||||
else:
|
||||
this_snapshot_dt = None
|
||||
# We have been fastforwarding. Update the universe
|
||||
# and check if we can start a new snapshot.
|
||||
# The algorithm has been missing events because it took
|
||||
# too long processing. Update the universe with data from
|
||||
# this event, then check if enough time has passed that we
|
||||
# can start a new snapshot.
|
||||
else:
|
||||
self.update_universe(event)
|
||||
if event.dt >= self.simulation_dt:
|
||||
self.this_snapshot_dt = event.dt
|
||||
|
||||
def update_current_snapshot(self, event):
|
||||
"""
|
||||
Update our current snapshot of the universe. Call handle_data if
|
||||
"""
|
||||
# The new event matches our snapshot dt. Just update the
|
||||
# universe and move on.
|
||||
if event.dt == self.this_snapshot_dt:
|
||||
self.update_universe(event)
|
||||
|
||||
# The new event does not match our snapshot.
|
||||
else:
|
||||
update_universe(event, universe)
|
||||
if event.dt >= simulation_dt:
|
||||
this_snapshot_dt = event.dt
|
||||
|
||||
self.simulate_current_snapshot()
|
||||
|
||||
# Once we've finished simulating the old snapshot,
|
||||
# we can update the universe with the new event.
|
||||
self.update_universe(event)
|
||||
|
||||
# The current event is later than the simulation time,
|
||||
# which means the algorithm finished quickly enough to
|
||||
# receive the new event. Start a new snapshot with this
|
||||
# event's dt.
|
||||
if event.dt >= self.simulation_dt:
|
||||
self.this_snapshot_dt = event.dt
|
||||
|
||||
|
||||
# The algorithm spent enough time processing that it
|
||||
# missed the new event. Wait to start a new snapshot until
|
||||
# the events catch up to the algo's simulated dt.
|
||||
else:
|
||||
self.this_snapshot_dt = None
|
||||
|
||||
def simulate_current_snapshot(self):
|
||||
"""
|
||||
Run the user's algo against our current snapshot and update the algo's
|
||||
simulated time.
|
||||
"""
|
||||
start_tic = datetime.now()
|
||||
self.algo.handle_data(self.universe)
|
||||
stop_tic = datetime.now()
|
||||
|
||||
# How long did you take?
|
||||
delta = stop_tic - start_tic
|
||||
|
||||
# Update the simulation time.
|
||||
self.simulation_dt = self.this_snapshot_dt + delta
|
||||
|
||||
def update_universe(event, universe):
|
||||
|
||||
universe.portfolio = event.portfolio
|
||||
del event['portfolio']
|
||||
def update_universe(self, event):
|
||||
"""
|
||||
Update the universe with new event information.
|
||||
"""
|
||||
# Update our portfolio.
|
||||
self.universe.portfolio = event.portfolio
|
||||
|
||||
event_sid = event.sid
|
||||
del event['sid']
|
||||
|
||||
for field in event.keys():
|
||||
universe[event_sid][field] = event[field]
|
||||
# Update our knowledge of this event's sid
|
||||
for field in event.keys():
|
||||
self.universe[event.sid][field] = event[field]
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ def functional_transform(stream_in, func, *args, **kwargs):
|
||||
class StatefulTransform(object):
|
||||
"""
|
||||
Generic transform generator that takes each message from an
|
||||
in-stream and passes it to a state class. For each call to
|
||||
in-stream and passes it to a state object. For each call to
|
||||
update, the state class must produce a message to be fed
|
||||
downstream. Any transform class with the FORWARDER class variable
|
||||
set to true will forward all fields in the original message.
|
||||
@@ -63,16 +63,26 @@ class StatefulTransform(object):
|
||||
# Create an instance of our transform class.
|
||||
self.state = tnfm_class(*args, **kwargs)
|
||||
|
||||
# Generate the string associated with this generator's output.
|
||||
# Create the string associated with this generator's output.
|
||||
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
# Generator isn't initialized until someone calls __iter__ or next().
|
||||
self.__generator = None
|
||||
|
||||
def get_hash(self):
|
||||
return self.namestring
|
||||
|
||||
def next(self):
|
||||
if self.__generator:
|
||||
return self.__generator.next()
|
||||
else:
|
||||
self.__generator = self._gen()
|
||||
return self.__generator.next()
|
||||
|
||||
def __iter__(self):
|
||||
return self.gen()
|
||||
|
||||
def gen(self):
|
||||
return self
|
||||
|
||||
def _gen(self):
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
for message in self.stream_in:
|
||||
|
||||
Reference in New Issue
Block a user