mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 20:06:18 +08:00
@@ -8,11 +8,16 @@ from zipline.core.component import Component
|
||||
from zipline.finance.trading import TransactionSimulator
|
||||
from zipline.utils.protocol_utils import ndict
|
||||
|
||||
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
|
||||
|
||||
from logbook import Logger, NestedSetup, Processor, queues
|
||||
|
||||
log = logbook.Logger('TradeSimulation')
|
||||
|
||||
|
||||
class TradeSimulationClient(Component):
|
||||
|
||||
def init(self, trading_environment, sim_style):
|
||||
def init(self, trading_environment, sim_style, log_socket):
|
||||
self.received_count = 0
|
||||
self.prev_dt = None
|
||||
self.event_queue = None
|
||||
@@ -28,12 +33,15 @@ class TradeSimulationClient(Component):
|
||||
|
||||
self.event_data = ndict()
|
||||
self.perf = perf.PerformanceTracker(self.trading_environment)
|
||||
|
||||
|
||||
self.log_socket = log_socket
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return str(zp.FINANCE_COMPONENT.TRADING_CLIENT)
|
||||
|
||||
def set_algorithm(self, algorithm):
|
||||
|
||||
"""
|
||||
:param algorithm: must implement the algorithm protocol. See
|
||||
:py:mod:`zipline.test.algorithm`
|
||||
@@ -41,17 +49,41 @@ class TradeSimulationClient(Component):
|
||||
self.algorithm = algorithm
|
||||
# register the trading_client's order method with the algorithm
|
||||
self.algorithm.set_order(self.order)
|
||||
# ask the algorithm to initialize
|
||||
self.algorithm.initialize()
|
||||
|
||||
#TODO: re-enable initialization logging. This means we can't call set_algorithm
|
||||
#until we have a context for this component. Possibly this could happen
|
||||
# ask the algorithm to initialize, routing stdout to a zmq PUSH socket.
|
||||
|
||||
#with self.zmq_out.threadbound(), self.stdout_capture(self.logger, 'Algo print capture'):
|
||||
# self.algorithm.initialize()
|
||||
#if we don't have a log socket, initialize anyway.
|
||||
#else:
|
||||
# self.algorithm.initialize()
|
||||
|
||||
self.algorithm.initialize()
|
||||
|
||||
def open(self):
|
||||
|
||||
self.result_feed = self.connect_result()
|
||||
self.perf.open(self.context)
|
||||
|
||||
#If we have a log socket,setup context manager for exporting captured
|
||||
#print statements
|
||||
if self.log_socket:
|
||||
self.zmq_out = ZeroMQLogHandler(uri = self.log_socket, context = self.context)
|
||||
self.logger = Logger("Print")
|
||||
self.stdout_capture = stdout_only_pipe #THIS IS A CLASS!
|
||||
|
||||
#Initialize log capture for testing purposes.
|
||||
def setup_logging(self, context):
|
||||
if self.log_socket:
|
||||
self.zmq_out = ZeroMQLogHandler(uri = self.log_socket, context = context)
|
||||
self.logger = Logger("Print")
|
||||
self.stdout_capture = stdout_only_pipe #THIS IS A CLASS!
|
||||
|
||||
def do_work(self):
|
||||
# poll all the sockets
|
||||
socks = dict(self.poll.poll(self.heartbeat_timeout))
|
||||
|
||||
# see if the poller has results for the result_feed
|
||||
if socks.get(self.result_feed) == self.zmq.POLLIN:
|
||||
|
||||
@@ -69,6 +101,7 @@ class TradeSimulationClient(Component):
|
||||
event = zp.MERGE_UNFRAME(msg)
|
||||
self.received_count += 1
|
||||
# update performance and relay the event to the algorithm
|
||||
|
||||
self.process_event(event)
|
||||
if self.perf.exceeded_max_loss:
|
||||
self.finish_simulation()
|
||||
@@ -82,9 +115,8 @@ class TradeSimulationClient(Component):
|
||||
# signal Simulator, our ComponentHost, that this component is
|
||||
# done and Simulator needn't block exit on this component.
|
||||
self.signal_done()
|
||||
|
||||
|
||||
def process_event(self, event):
|
||||
|
||||
# generate transactions, if applicable
|
||||
txn = self.txn_sim.apply_trade_to_open_orders(event)
|
||||
if txn:
|
||||
@@ -109,10 +141,10 @@ class TradeSimulationClient(Component):
|
||||
# queue the event.
|
||||
self.queue_event(event)
|
||||
|
||||
|
||||
# if the event is later than our current time, run the algo
|
||||
# otherwise, the algorithm has fallen behind the feed
|
||||
# and processing per event is longer than time between events.
|
||||
|
||||
if event.dt >= self.current_dt:
|
||||
# compress time by moving the current_time up to the event
|
||||
# time.
|
||||
@@ -124,7 +156,6 @@ class TradeSimulationClient(Component):
|
||||
# move the algorithm's clock forward to include iteration time
|
||||
self.current_dt = self.current_dt + self.last_iteration_dur
|
||||
|
||||
|
||||
def run_algorithm(self):
|
||||
"""
|
||||
As per the algorithm protocol:
|
||||
@@ -135,8 +166,44 @@ class TradeSimulationClient(Component):
|
||||
current_portfolio = self.perf.get_portfolio()
|
||||
self.algorithm.set_portfolio(current_portfolio)
|
||||
data = self.get_data()
|
||||
|
||||
if len(data) > 0:
|
||||
self.algorithm.handle_data(data)
|
||||
|
||||
# data injection pipeline for log rerouting
|
||||
# any fields injected here should be added to
|
||||
# LOG_EXTRA_FIELDS in zipline/protocol.py
|
||||
|
||||
if self.log_socket:
|
||||
|
||||
def inject_event_data(record):
|
||||
|
||||
#Record the simulation time.
|
||||
|
||||
record.extra['algo_dt'] = self.current_dt
|
||||
|
||||
data_injector = Processor(inject_event_data)
|
||||
log_pipeline = NestedSetup([self.zmq_out,
|
||||
#e.g. FileHandler(...)
|
||||
data_injector])
|
||||
with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''):
|
||||
self.algorithm.handle_data(data)
|
||||
# if no log socket, just run the algo normally
|
||||
else:
|
||||
self.algorithm.handle_data(data)
|
||||
|
||||
#Testing utility for log capture.
|
||||
def test_run_algorithm(self):
|
||||
|
||||
def inject_event_data(record):
|
||||
record.extra['algo_dt'] = datetime.datetime.utcnow() #Mock an event.dt
|
||||
|
||||
data_injector = Processor(inject_event_data)
|
||||
log_pipeline = NestedSetup([self.zmq_out,
|
||||
#e.g. FileHandler(...)
|
||||
data_injector])
|
||||
with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''):
|
||||
self.algorithm.handle_data('data')
|
||||
# if no log socket, just run the algo normally
|
||||
|
||||
def connect_order(self):
|
||||
return self.connect_push_socket(self.addresses['order_address'])
|
||||
|
||||
@@ -213,7 +213,7 @@ class Component(object):
|
||||
Run the component.
|
||||
|
||||
Optionally takes an argument to catch and log all exceptions
|
||||
raised during execution ues this with care since it makes it
|
||||
raised during execution. Use this with care since it makes it
|
||||
very hard to debug since it mucks up your stacktraces.
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Simulator hosts all the components necessary to execute a simluation.
|
||||
Simulator hosts all the components necessary to execute a simulation.
|
||||
See :py:method""
|
||||
"""
|
||||
|
||||
|
||||
+12
-12
@@ -210,7 +210,7 @@ class Controller(object):
|
||||
|
||||
assert self.route_socket
|
||||
assert self.pub_socket
|
||||
assert self.cancel_socket
|
||||
#assert self.cancel_socket
|
||||
|
||||
# -- Publish --
|
||||
# =============
|
||||
@@ -219,9 +219,9 @@ class Controller(object):
|
||||
|
||||
# -- Cancel --
|
||||
# =============
|
||||
assert isinstance(self.cancel_socket,basestring), self.cancel_socket
|
||||
self.cancel = self.context.socket(self.zmq.REP)
|
||||
self.cancel.connect(self.cancel_socket)
|
||||
#assert isinstance(self.cancel_socket,basestring), self.cancel_socket
|
||||
#self.cancel = self.context.socket(self.zmq.REP)
|
||||
#self.cancel.connect(self.cancel_socket)
|
||||
|
||||
# -- Router --
|
||||
# =============
|
||||
@@ -231,9 +231,9 @@ class Controller(object):
|
||||
|
||||
poller = self.zmq.Poller()
|
||||
poller.register(self.router, self.zmq.POLLIN)
|
||||
poller.register(self.cancel, self.zmq.POLLIN)
|
||||
#poller.register(self.cancel, self.zmq.POLLIN)
|
||||
|
||||
self.associated += [self.pub, self.router, self.cancel]
|
||||
self.associated += [self.pub, self.router]# self.cancel]
|
||||
|
||||
# TODO: actually do this
|
||||
self.state = CONTROL_STATES.SOURCES_READY
|
||||
@@ -270,12 +270,12 @@ class Controller(object):
|
||||
log.error('Invalid frame', rawmessage)
|
||||
pass
|
||||
|
||||
if socks.get(self.cancel) == self.zmq.POLLIN:
|
||||
log.info('Received Cancellation')
|
||||
rawmessage = self.cancel.recv()
|
||||
self.cancel.send('')
|
||||
self.shutdown(soft=True)
|
||||
break
|
||||
#if socks.get(self.cancel) == self.zmq.POLLIN:
|
||||
# self.logging.info('[Controller] Received Cancellation')
|
||||
# rawmessage = self.cancel.recv()
|
||||
# self.cancel.send('')
|
||||
# self.shutdown(soft=True)
|
||||
# break
|
||||
|
||||
self.beat()
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ class PerformanceTracker(object):
|
||||
self.todays_performance.calculate_performance()
|
||||
|
||||
def handle_market_close(self):
|
||||
|
||||
|
||||
# add the return results from today to the list of DailyReturn objects.
|
||||
todays_date = self.market_close.replace(hour=0, minute=0, second=0)
|
||||
todays_return_obj = risk.DailyReturn(
|
||||
@@ -347,7 +347,7 @@ class PerformanceTracker(object):
|
||||
if self.results_socket:
|
||||
log.info("about to stream the risk report...")
|
||||
risk_dict = self.risk_report.to_dict()
|
||||
|
||||
|
||||
msg = zp.RISK_FRAME(risk_dict)
|
||||
self.results_socket.send(msg)
|
||||
# this signals that the simulation is complete.
|
||||
|
||||
+17
-11
@@ -95,7 +95,7 @@ class SimulatedTrading(object):
|
||||
:param config: a dict with the following required properties::
|
||||
|
||||
- algorithm: a class that follows the algorithm protocol. See
|
||||
:py:meth:`zipline.finance.trading.TradingSimulationClient.add_algorithm
|
||||
:py:meth:`zipline.finance.trading.TradeSimulationClient.add_algorithm
|
||||
for details.
|
||||
- trading_environment: an instance of
|
||||
:py:class:`zipline.trading.TradingEnvironment`
|
||||
@@ -123,16 +123,14 @@ class SimulatedTrading(object):
|
||||
'feed_address' : sockets[2],
|
||||
'merge_address' : sockets[3],
|
||||
'result_address' : sockets[4],
|
||||
'order_address' : sockets[5]
|
||||
'order_address' : sockets[5]
|
||||
}
|
||||
|
||||
self.con = Controller(
|
||||
sockets[6],
|
||||
sockets[7],
|
||||
)
|
||||
|
||||
self.con.cancel_socket = self.allocator.lease(1)[0]
|
||||
|
||||
|
||||
# TODO: Not freeform
|
||||
self.con.manage(
|
||||
'freeform'
|
||||
@@ -143,9 +141,11 @@ class SimulatedTrading(object):
|
||||
self.sim = config['simulator_class'](addresses)
|
||||
|
||||
self.clients = {}
|
||||
|
||||
self.trading_client = TradeSimulationClient(
|
||||
self.trading_environment,
|
||||
self.sim_style
|
||||
self.sim_style,
|
||||
config['log_socket']
|
||||
)
|
||||
self.add_client(self.trading_client)
|
||||
|
||||
@@ -167,7 +167,7 @@ class SimulatedTrading(object):
|
||||
def create_test_zipline(**config):
|
||||
"""
|
||||
:param config: A configuration object that is a dict with:
|
||||
|
||||
|
||||
- environment - a \
|
||||
:py:class:`zipline.finance.trading.TradingEnvironment`
|
||||
- allocator - a :py:class:`zipline.simulator.AddressAllocator`
|
||||
@@ -190,7 +190,7 @@ class SimulatedTrading(object):
|
||||
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
|
||||
"""
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
allocator = config['allocator']
|
||||
sid = config['sid']
|
||||
|
||||
@@ -252,6 +252,11 @@ class SimulatedTrading(object):
|
||||
order_amount,
|
||||
order_count
|
||||
)
|
||||
|
||||
if config.has_key('log_socket'):
|
||||
log_socket = config['log_socket']
|
||||
else:
|
||||
log_socket = None
|
||||
#-------------------
|
||||
# Simulation
|
||||
#-------------------
|
||||
@@ -260,7 +265,8 @@ class SimulatedTrading(object):
|
||||
'trading_environment' : trading_environment,
|
||||
'allocator' : allocator,
|
||||
'simulator_class' : simulator_class,
|
||||
'simulation_style' : simulation_style
|
||||
'simulation_style' : simulation_style,
|
||||
'log_socket' : log_socket
|
||||
})
|
||||
#-------------------
|
||||
|
||||
@@ -301,8 +307,8 @@ class SimulatedTrading(object):
|
||||
def get_cumulative_performance(self):
|
||||
return self.trading_client.perf.cumulative_performance.to_dict()
|
||||
|
||||
def publish_to(self, result_socket):
|
||||
self.trading_client.perf.publish_to(result_socket)
|
||||
def publish_to(self, results_socket):
|
||||
self.trading_client.perf.publish_to(results_socket)
|
||||
|
||||
def allocate_sockets(self, n):
|
||||
"""
|
||||
|
||||
@@ -118,6 +118,7 @@ import msgpack
|
||||
import numbers
|
||||
import datetime
|
||||
import pytz
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from utils.protocol_utils import Enum, FrameExceptionFactory, ndict, namelookup
|
||||
@@ -601,3 +602,57 @@ SIMULATION_STYLE = Enum(
|
||||
'FIXED_SLIPPAGE',
|
||||
'NOOP'
|
||||
)
|
||||
|
||||
#Global variables for the fields we extract out of a standard logbook record.
|
||||
LOG_FIELDS = set(['func_name', 'lineno', 'time', 'msg',\
|
||||
'level', 'channel', ])
|
||||
LOG_EXTRA_FIELDS = set(['algo_dt',])
|
||||
|
||||
def LOG_FRAME(payload):
|
||||
"""
|
||||
Expects a dictionary of the form:
|
||||
{
|
||||
'algo_dt' : 1199223000, #Algo simulation date.
|
||||
'time' : 1199223001, #Realtime date of log creation.
|
||||
'func_name' : 'foo',
|
||||
'lineno' : 46,
|
||||
'msg' : 'Successfully disintegrated llama #3',
|
||||
'level' : 4, #Logbook enum
|
||||
'channel' : 'MyLogger'
|
||||
}
|
||||
|
||||
Frame checks that we have all expected fields and exports an
|
||||
event/payload dict as JSON.
|
||||
"""
|
||||
|
||||
assert isinstance(payload, dict), \
|
||||
"LOG_FRAME expected a dict"
|
||||
|
||||
assert payload.has_key('algo_dt'), \
|
||||
"LOG_FRAME with no algo_dt"
|
||||
assert payload.has_key('time'), \
|
||||
"LOG_FRAME with no time"
|
||||
assert payload.has_key('channel'),\
|
||||
"LOG_FRAME with no channel"
|
||||
assert payload.has_key('level'),\
|
||||
"LOG_FRAME with no level"
|
||||
assert payload.has_key('msg'),\
|
||||
"LOG_FRAME with no message"
|
||||
|
||||
data = {}
|
||||
data['e'] = 'LOG'
|
||||
data['p'] = payload
|
||||
|
||||
return msgpack.dumps(data)
|
||||
|
||||
def LOG_UNFRAME(msg):
|
||||
"""
|
||||
Expects a json serialized dictionary in event/payload format.
|
||||
"""
|
||||
record = msgpack.loads(msg)
|
||||
assert record['e'] == 'LOG'
|
||||
assert record.has_key('p')
|
||||
|
||||
return record['p']
|
||||
|
||||
|
||||
|
||||
@@ -137,3 +137,24 @@ class NoopAlgorithm(object):
|
||||
|
||||
def get_sid_filter(self):
|
||||
return None
|
||||
|
||||
class TestPrintAlgorithm():
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def initialize(self):
|
||||
print "Initializing..."
|
||||
|
||||
def set_order(self, order_callable):
|
||||
pass
|
||||
|
||||
def set_portfolio(self, portfolio):
|
||||
pass
|
||||
|
||||
def handle_data(self, data):
|
||||
print "Handling Data..."
|
||||
pass
|
||||
|
||||
def get_sid_filter(self):
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
import logbook
|
||||
import zmq
|
||||
import pytz
|
||||
import datetime
|
||||
|
||||
from logbook import NOTSET
|
||||
from logbook.handlers import Handler
|
||||
|
||||
from zipline.protocol import LOG_FRAME, LOG_FIELDS, \
|
||||
LOG_EXTRA_FIELDS
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
class redirecter(object):
|
||||
def __init__(self, logger, name):
|
||||
self.logger = logger
|
||||
self.buffer = bytes()
|
||||
self.name = name
|
||||
|
||||
def write(self, line):
|
||||
self.buffer += ''.join(['>>> ', line.strip('\n'), '\n'])
|
||||
|
||||
def flush(self, final=False):
|
||||
if not self.buffer:
|
||||
return
|
||||
out_form = """ [{pipe_name}] \n{buffer}""".format(
|
||||
pipe_name = self.name,
|
||||
buffer = self.buffer
|
||||
)
|
||||
self.logger.error(out_form)
|
||||
self.buffer = bytes()
|
||||
|
||||
class log_redirecter(object):
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
|
||||
def write(self, line):
|
||||
#Absorb blank lines from print statements.
|
||||
if line =='\n':
|
||||
return
|
||||
|
||||
else:
|
||||
#TODO: add logic to guarantee we made this
|
||||
self.logger.info(line.strip('\n'))
|
||||
|
||||
def flush(self, final=False):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def stdout_pipe(logger, pipe_name):
|
||||
"""
|
||||
Pipe stdout and stderr into a python logger interface
|
||||
"""
|
||||
import sys
|
||||
orig_fds = sys.stdout, sys.stderr
|
||||
|
||||
sys.stderr = redirecter(logger, pipe_name)
|
||||
sys.stdout = redirecter(logger, pipe_name)
|
||||
|
||||
yield
|
||||
sys.stderr.flush()
|
||||
sys.stdout.flush()
|
||||
sys.stdout, sys.stderr = orig_fds
|
||||
|
||||
@contextmanager
|
||||
def stdout_only_pipe(logger, pipe_name):
|
||||
"""
|
||||
Pipes just stdout into a python logger interface
|
||||
"""
|
||||
import sys
|
||||
orig_fd = sys.stdout
|
||||
sys.stdout = log_redirecter(logger)
|
||||
|
||||
yield
|
||||
sys.stdout.flush()
|
||||
sys.stdout = orig_fd
|
||||
|
||||
class ZeroMQLogHandler(Handler):
|
||||
"""
|
||||
A handler that takes messages captured from the user algorithm stdout
|
||||
and transforms them into LOG_FRAMES suitable for database storage.
|
||||
Setup is similar to logbook.queues.ZeroMQHandler, except we connect
|
||||
instead of binding and we extract record fields into a dict.
|
||||
"""
|
||||
|
||||
def __init__(self, uri=None, level=NOTSET, filter=None, bubble=False,
|
||||
context=None, fds = LOG_FIELDS, extra_fds = LOG_EXTRA_FIELDS):
|
||||
Handler.__init__(self, level, filter, bubble)
|
||||
try:
|
||||
import zmq
|
||||
except ImportError:
|
||||
raise RuntimeError('The pyzmq library is required for '
|
||||
'the ZeroMQHandler.')
|
||||
#: the zero mq context
|
||||
self.context = context
|
||||
#: the zero mq socket.
|
||||
self.socket = self.context.socket(zmq.PUSH)
|
||||
|
||||
self.uri = uri
|
||||
if uri is not None:
|
||||
self.socket.connect(uri)
|
||||
|
||||
self.fds = fds
|
||||
self.extra_fds = extra_fds
|
||||
|
||||
def export_record(self, record):
|
||||
"""
|
||||
Extract relevant fields from a log record, fiddling with datetime
|
||||
fields to make json happy.
|
||||
"""
|
||||
from zipline.utils.date_utils import EPOCH
|
||||
|
||||
#Needed to extract record info from dictionary.
|
||||
record.pull_information()
|
||||
|
||||
#Logbook stores record times as datetime objects, which
|
||||
#can't be serialized by JSON, so we need to convert to
|
||||
#unix epoch representation.
|
||||
|
||||
if record.time:
|
||||
assert isinstance(record.time, datetime.datetime)
|
||||
|
||||
time = record.time.replace(tzinfo = pytz.utc)
|
||||
#logbook measures time in utc already, no need to convert.
|
||||
record.time = EPOCH(time)
|
||||
|
||||
#Do the same if algo_dt is a datetime object.
|
||||
if record.extra.has_key('algo_dt'):
|
||||
algo_dt = record.extra['algo_dt']
|
||||
|
||||
if isinstance(algo_dt, datetime.datetime):
|
||||
algo_dt = EPOCH(algo_dt.replace(tzinfo = pytz.utc))
|
||||
record.extra['algo_dt'] = algo_dt
|
||||
|
||||
data = {}
|
||||
|
||||
#Extract all the fields we care about from LogRecord's internal
|
||||
#dictionary.
|
||||
|
||||
for field in iter(self.fds):
|
||||
if record.__dict__.has_key(field):
|
||||
data[field] = record.__dict__[field]
|
||||
else:
|
||||
data[field] = None
|
||||
|
||||
for field in iter(self.extra_fds):
|
||||
if record.extra.has_key(field):
|
||||
data[field] = record.extra[field]
|
||||
else:
|
||||
data[field] = None
|
||||
return data
|
||||
|
||||
def emit(self, record):
|
||||
"""Extract relevant fields and send info as JSON over a zmq socket."""
|
||||
payload = self.export_record(record)
|
||||
self.socket.send(LOG_FRAME(payload))
|
||||
|
||||
def close(self):
|
||||
#self.socket.close()
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user