Merge pull request #72 from quantopian/scott1

Scott1
This commit is contained in:
Scott Sanderson
2012-07-06 07:58:33 -07:00
9 changed files with 347 additions and 37 deletions
+77 -10
View File
@@ -8,11 +8,16 @@ from zipline.core.component import Component
from zipline.finance.trading import TransactionSimulator
from zipline.utils.protocol_utils import ndict
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
from logbook import Logger, NestedSetup, Processor, queues
log = logbook.Logger('TradeSimulation')
class TradeSimulationClient(Component):
def init(self, trading_environment, sim_style):
def init(self, trading_environment, sim_style, log_socket):
self.received_count = 0
self.prev_dt = None
self.event_queue = None
@@ -28,12 +33,15 @@ class TradeSimulationClient(Component):
self.event_data = ndict()
self.perf = perf.PerformanceTracker(self.trading_environment)
self.log_socket = log_socket
@property
def get_id(self):
    """Identifier of this component: the TRADING_CLIENT constant rendered as a string."""
    component_id = zp.FINANCE_COMPONENT.TRADING_CLIENT
    return str(component_id)
def set_algorithm(self, algorithm):
"""
:param algorithm: must implement the algorithm protocol. See
:py:mod:`zipline.test.algorithm`
@@ -41,17 +49,41 @@ class TradeSimulationClient(Component):
self.algorithm = algorithm
# register the trading_client's order method with the algorithm
self.algorithm.set_order(self.order)
# ask the algorithm to initialize
self.algorithm.initialize()
#TODO: re-enable initialization logging. This means we can't call set_algorithm
#until we have a context for this component. Possibly this could happen
# ask the algorithm to initialize, routing stdout to a zmq PUSH socket.
#with self.zmq_out.threadbound(), self.stdout_capture(self.logger, 'Algo print capture'):
# self.algorithm.initialize()
#if we don't have a log socket, initialize anyway.
#else:
# self.algorithm.initialize()
self.algorithm.initialize()
def open(self):
    """
    Connect the result feed, open the performance tracker, and install
    print-statement capture when a log socket was supplied.

    Log capture is delegated to :py:meth:`setup_logging` (which is a no-op
    when ``self.log_socket`` is falsy) so the handler wiring lives in one
    place instead of being duplicated here.
    """
    self.result_feed = self.connect_result()
    self.perf.open(self.context)
    # Uses this component's own zmq context for the log handler.
    self.setup_logging(self.context)
#Initialize log capture for testing purposes.
def setup_logging(self, context):
    """
    Prepare the attributes used to capture print statements.

    Does nothing when ``self.log_socket`` is falsy. Otherwise sets
    ``self.zmq_out`` (a ZeroMQLogHandler connected to the log socket using
    the given zmq ``context``), ``self.logger`` and ``self.stdout_capture``.
    """
    if not self.log_socket:
        return
    self.zmq_out = ZeroMQLogHandler(uri=self.log_socket, context=context)
    self.logger = Logger("Print")
    # Stored uncalled: stdout_only_pipe is invoked later, once per capture.
    self.stdout_capture = stdout_only_pipe
def do_work(self):
# poll all the sockets
socks = dict(self.poll.poll(self.heartbeat_timeout))
# see if the poller has results for the result_feed
if socks.get(self.result_feed) == self.zmq.POLLIN:
@@ -69,6 +101,7 @@ class TradeSimulationClient(Component):
event = zp.MERGE_UNFRAME(msg)
self.received_count += 1
# update performance and relay the event to the algorithm
self.process_event(event)
if self.perf.exceeded_max_loss:
self.finish_simulation()
@@ -82,9 +115,8 @@ class TradeSimulationClient(Component):
# signal Simulator, our ComponentHost, that this component is
# done and Simulator needn't block exit on this component.
self.signal_done()
def process_event(self, event):
# generate transactions, if applicable
txn = self.txn_sim.apply_trade_to_open_orders(event)
if txn:
@@ -109,10 +141,10 @@ class TradeSimulationClient(Component):
# queue the event.
self.queue_event(event)
# if the event is later than our current time, run the algo
# otherwise, the algorithm has fallen behind the feed
# and processing per event is longer than time between events.
if event.dt >= self.current_dt:
# compress time by moving the current_time up to the event
# time.
@@ -124,7 +156,6 @@ class TradeSimulationClient(Component):
# move the algorithm's clock forward to include iteration time
self.current_dt = self.current_dt + self.last_iteration_dur
def run_algorithm(self):
"""
As per the algorithm protocol:
@@ -135,8 +166,44 @@ class TradeSimulationClient(Component):
current_portfolio = self.perf.get_portfolio()
self.algorithm.set_portfolio(current_portfolio)
data = self.get_data()
if len(data) > 0:
self.algorithm.handle_data(data)
# data injection pipeline for log rerouting
# any fields injected here should be added to
# LOG_EXTRA_FIELDS in zipline/protocol.py
if self.log_socket:
def inject_event_data(record):
#Record the simulation time.
record.extra['algo_dt'] = self.current_dt
data_injector = Processor(inject_event_data)
log_pipeline = NestedSetup([self.zmq_out,
#e.g. FileHandler(...)
data_injector])
with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''):
self.algorithm.handle_data(data)
# if no log socket, just run the algo normally
else:
self.algorithm.handle_data(data)
#Testing utility for log capture.
def test_run_algorithm(self):
    """
    Call the algorithm's ``handle_data`` once, with stdout captured into
    the log pipeline.

    Reads ``self.zmq_out``, ``self.logger`` and ``self.stdout_capture``,
    which ``setup_logging`` assigns when a log socket is configured. Every
    record is stamped with the current UTC time in place of a real
    simulation time.
    """
    def stamp_algo_dt(log_record):
        #Mock an event.dt
        log_record.extra['algo_dt'] = datetime.datetime.utcnow()

    injector = Processor(stamp_algo_dt)
    pipeline = NestedSetup([
        self.zmq_out,
        #e.g. FileHandler(...)
        injector,
    ])
    with pipeline.threadbound():
        with self.stdout_capture(self.logger, ''):
            self.algorithm.handle_data('data')
def connect_order(self):
    """Return ``connect_push_socket`` applied to the configured 'order_address'."""
    order_address = self.addresses['order_address']
    return self.connect_push_socket(order_address)
+1 -1
View File
@@ -213,7 +213,7 @@ class Component(object):
Run the component.
Optionally takes an argument to catch and log all exceptions
raised during execution ues this with care since it makes it
raised during execution. Use this with care since it makes it
very hard to debug since it mucks up your stacktraces.
"""
+1 -1
View File
@@ -1,5 +1,5 @@
"""
Simulator hosts all the components necessary to execute a simluation.
Simulator hosts all the components necessary to execute a simulation.
See :py:method""
"""
+12 -12
View File
@@ -210,7 +210,7 @@ class Controller(object):
assert self.route_socket
assert self.pub_socket
assert self.cancel_socket
#assert self.cancel_socket
# -- Publish --
# =============
@@ -219,9 +219,9 @@ class Controller(object):
# -- Cancel --
# =============
assert isinstance(self.cancel_socket,basestring), self.cancel_socket
self.cancel = self.context.socket(self.zmq.REP)
self.cancel.connect(self.cancel_socket)
#assert isinstance(self.cancel_socket,basestring), self.cancel_socket
#self.cancel = self.context.socket(self.zmq.REP)
#self.cancel.connect(self.cancel_socket)
# -- Router --
# =============
@@ -231,9 +231,9 @@ class Controller(object):
poller = self.zmq.Poller()
poller.register(self.router, self.zmq.POLLIN)
poller.register(self.cancel, self.zmq.POLLIN)
#poller.register(self.cancel, self.zmq.POLLIN)
self.associated += [self.pub, self.router, self.cancel]
self.associated += [self.pub, self.router]# self.cancel]
# TODO: actually do this
self.state = CONTROL_STATES.SOURCES_READY
@@ -270,12 +270,12 @@ class Controller(object):
log.error('Invalid frame', rawmessage)
pass
if socks.get(self.cancel) == self.zmq.POLLIN:
log.info('Received Cancellation')
rawmessage = self.cancel.recv()
self.cancel.send('')
self.shutdown(soft=True)
break
#if socks.get(self.cancel) == self.zmq.POLLIN:
# self.logging.info('[Controller] Received Cancellation')
# rawmessage = self.cancel.recv()
# self.cancel.send('')
# self.shutdown(soft=True)
# break
self.beat()
+2 -2
View File
@@ -261,7 +261,7 @@ class PerformanceTracker(object):
self.todays_performance.calculate_performance()
def handle_market_close(self):
# add the return results from today to the list of DailyReturn objects.
todays_date = self.market_close.replace(hour=0, minute=0, second=0)
todays_return_obj = risk.DailyReturn(
@@ -347,7 +347,7 @@ class PerformanceTracker(object):
if self.results_socket:
log.info("about to stream the risk report...")
risk_dict = self.risk_report.to_dict()
msg = zp.RISK_FRAME(risk_dict)
self.results_socket.send(msg)
# this signals that the simulation is complete.
+17 -11
View File
@@ -95,7 +95,7 @@ class SimulatedTrading(object):
:param config: a dict with the following required properties::
- algorithm: a class that follows the algorithm protocol. See
:py:meth:`zipline.finance.trading.TradingSimulationClient.add_algorithm
:py:meth:`zipline.finance.trading.TradeSimulationClient.set_algorithm`
for details.
- trading_environment: an instance of
:py:class:`zipline.trading.TradingEnvironment`
@@ -123,16 +123,14 @@ class SimulatedTrading(object):
'feed_address' : sockets[2],
'merge_address' : sockets[3],
'result_address' : sockets[4],
'order_address' : sockets[5]
'order_address' : sockets[5]
}
self.con = Controller(
sockets[6],
sockets[7],
)
self.con.cancel_socket = self.allocator.lease(1)[0]
# TODO: Not freeform
self.con.manage(
'freeform'
@@ -143,9 +141,11 @@ class SimulatedTrading(object):
self.sim = config['simulator_class'](addresses)
self.clients = {}
self.trading_client = TradeSimulationClient(
self.trading_environment,
self.sim_style
self.sim_style,
config['log_socket']
)
self.add_client(self.trading_client)
@@ -167,7 +167,7 @@ class SimulatedTrading(object):
def create_test_zipline(**config):
"""
:param config: A configuration object that is a dict with:
- environment - a \
:py:class:`zipline.finance.trading.TradingEnvironment`
- allocator - a :py:class:`zipline.simulator.AddressAllocator`
@@ -190,7 +190,7 @@ class SimulatedTrading(object):
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
"""
assert isinstance(config, dict)
allocator = config['allocator']
sid = config['sid']
@@ -252,6 +252,11 @@ class SimulatedTrading(object):
order_amount,
order_count
)
if config.has_key('log_socket'):
log_socket = config['log_socket']
else:
log_socket = None
#-------------------
# Simulation
#-------------------
@@ -260,7 +265,8 @@ class SimulatedTrading(object):
'trading_environment' : trading_environment,
'allocator' : allocator,
'simulator_class' : simulator_class,
'simulation_style' : simulation_style
'simulation_style' : simulation_style,
'log_socket' : log_socket
})
#-------------------
@@ -301,8 +307,8 @@ class SimulatedTrading(object):
def get_cumulative_performance(self):
    """Return the trading client's cumulative performance as produced by its ``to_dict()``."""
    tracker = self.trading_client.perf
    cumulative = tracker.cumulative_performance
    return cumulative.to_dict()
def publish_to(self, result_socket):
self.trading_client.perf.publish_to(result_socket)
def publish_to(self, results_socket):
self.trading_client.perf.publish_to(results_socket)
def allocate_sockets(self, n):
"""
+55
View File
@@ -118,6 +118,7 @@ import msgpack
import numbers
import datetime
import pytz
from collections import namedtuple
from utils.protocol_utils import Enum, FrameExceptionFactory, ndict, namelookup
@@ -601,3 +602,57 @@ SIMULATION_STYLE = Enum(
'FIXED_SLIPPAGE',
'NOOP'
)
# Names of the attributes copied off a standard logbook record.
LOG_FIELDS = set([
    'func_name',
    'lineno',
    'time',
    'msg',
    'level',
    'channel',
])

# Names of the entries copied out of the record's ``extra`` dictionary.
LOG_EXTRA_FIELDS = set(['algo_dt'])
def LOG_FRAME(payload):
    """
    Frame a log payload for transport.

    Expects a dictionary of the form:
    {
        'algo_dt'   : 1199223000,  #Algo simulation date.
        'time'      : 1199223001,  #Realtime date of log creation.
        'func_name' : 'foo',
        'lineno'    : 46,
        'msg'       : 'Successfully disintegrated llama #3',
        'level'     : 4,           #Logbook enum
        'channel'   : 'MyLogger'
    }

    Checks that ``payload`` is a dict carrying ``algo_dt``, ``time``,
    ``channel``, ``level`` and ``msg`` (``func_name`` and ``lineno`` are not
    checked), wraps it as ``{'e': 'LOG', 'p': payload}`` and returns that
    structure serialized with msgpack.

    :raises AssertionError: if ``payload`` is not a dict or a checked key
        is missing. Raised explicitly so the validation still runs when
        Python is started with ``-O``.
    """
    if not isinstance(payload, dict):
        raise AssertionError("LOG_FRAME expected a dict")

    # (key, label used in the error message), checked in this order.
    required = (
        ('algo_dt', 'algo_dt'),
        ('time', 'time'),
        ('channel', 'channel'),
        ('level', 'level'),
        ('msg', 'message'),
    )
    for key, label in required:
        # ``in`` instead of dict.has_key, which no longer exists on Python 3.
        if key not in payload:
            raise AssertionError("LOG_FRAME with no %s" % label)

    return msgpack.dumps({'e': 'LOG', 'p': payload})
def LOG_UNFRAME(msg):
    """
    Inverse of LOG_FRAME: unpack a msgpack-serialized dictionary in
    event/payload format and return the payload.

    :raises AssertionError: if the unpacked event type is not 'LOG' or the
        record carries no 'p' entry.
    """
    record = msgpack.loads(msg)
    assert record['e'] == 'LOG'
    # ``in`` instead of dict.has_key, which no longer exists on Python 3.
    assert 'p' in record
    return record['p']
+21
View File
@@ -137,3 +137,24 @@ class NoopAlgorithm(object):
def get_sid_filter(self):
    """Always returns None."""
    return None
class TestPrintAlgorithm(object):
    """
    Minimal algorithm whose only observable behavior is printing to stdout.

    ``initialize`` and ``handle_data`` each print one fixed line; every
    other method is a no-op.
    """

    # The name starts with "Test" but this is a fixture, not a test case;
    # tell test collectors (nose, pytest) to leave it alone.
    __test__ = False

    def initialize(self):
        # Parenthesized form prints identically on Python 2 and Python 3.
        print("Initializing...")

    def set_order(self, order_callable):
        pass

    def set_portfolio(self, portfolio):
        pass

    def handle_data(self, data):
        print("Handling Data...")

    def get_sid_filter(self):
        return None
+161
View File
@@ -0,0 +1,161 @@
import logbook
import zmq
import pytz
import datetime
from logbook import NOTSET
from logbook.handlers import Handler
from zipline.protocol import LOG_FRAME, LOG_FIELDS, \
LOG_EXTRA_FIELDS
from contextlib import contextmanager
class redirecter(object):
    """
    File-like object that buffers written text and, on ``flush``, emits the
    whole buffer as a single ``logger.error`` record tagged with ``name``.
    """

    def __init__(self, logger, name):
        self.logger = logger
        # Text buffer. Initialised as a str (not ``bytes()``) so appending
        # text in ``write`` also works on Python 3; on Python 2 the two are
        # the same type.
        self.buffer = ''
        self.name = name

    def write(self, line):
        """Append ``line`` to the buffer, prefixed with '>>> ' and newline-terminated."""
        self.buffer += ''.join(['>>> ', line.strip('\n'), '\n'])

    def flush(self, final=False):
        """
        Log the buffered text at error level and clear the buffer.

        Does nothing when the buffer is empty. ``final`` is accepted but
        not used.
        """
        if not self.buffer:
            return

        out_form = " [{pipe_name}] \n{buffer}".format(
            pipe_name=self.name,
            buffer=self.buffer
        )
        self.logger.error(out_form)
        self.buffer = ''
class log_redirecter(object):
    """File-like stand-in for stdout that forwards each write to ``logger.info``."""

    def __init__(self, logger):
        self.logger = logger

    def write(self, line):
        """Log ``line`` (outer newlines removed) at info level; a bare newline is dropped."""
        #Absorb blank lines from print statements.
        if line != '\n':
            #TODO: add logic to guarantee we made this
            self.logger.info(line.strip('\n'))

    def flush(self, final=False):
        """No-op: nothing is buffered. ``final`` is accepted but not used."""
        return None
@contextmanager
def stdout_pipe(logger, pipe_name):
    """
    Pipe stdout and stderr into a python logger interface.

    While the context is active, ``sys.stdout`` and ``sys.stderr`` are each
    replaced by a ``redirecter`` bound to ``logger`` and ``pipe_name``. On
    exit both streams are flushed and the original streams are restored,
    including when the body raises.
    """
    import sys
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    sys.stderr = redirecter(logger, pipe_name)
    sys.stdout = redirecter(logger, pipe_name)
    try:
        yield
    finally:
        # Restore the real streams even if flushing itself fails, so the
        # process is never left with its output permanently redirected.
        try:
            sys.stderr.flush()
            sys.stdout.flush()
        finally:
            sys.stdout, sys.stderr = saved_stdout, saved_stderr
@contextmanager
def stdout_only_pipe(logger, pipe_name):
    """
    Pipes just stdout into a python logger interface.

    While the context is active, ``sys.stdout`` is replaced by a
    ``log_redirecter`` bound to ``logger``. ``pipe_name`` is accepted but
    not used. On exit stdout is flushed and the original stream is
    restored, including when the body raises.
    """
    import sys
    saved_stdout = sys.stdout
    sys.stdout = log_redirecter(logger)
    try:
        yield
    finally:
        # Restore the real stdout even if flushing itself fails.
        try:
            sys.stdout.flush()
        finally:
            sys.stdout = saved_stdout
class ZeroMQLogHandler(Handler):
    """
    A handler that takes messages captured from the user algorithm stdout
    and transforms them into LOG_FRAMES suitable for database storage.
    Setup is similar to logbook.queues.ZeroMQHandler, except we connect
    instead of binding and we extract record fields into a dict.
    """

    def __init__(self, uri=None, level=NOTSET, filter=None, bubble=False,
                 context=None, fds=LOG_FIELDS, extra_fds=LOG_EXTRA_FIELDS):
        """
        :param uri: zmq address to connect the PUSH socket to; when None
            the socket is created but left unconnected.
        :param level, filter, bubble: passed through to logbook's Handler.
        :param context: zmq context to create the socket from; a new
            ``zmq.Context()`` is created when omitted.
        :param fds: names of record attributes to export.
        :param extra_fds: names of ``record.extra`` entries to export.
        :raises RuntimeError: if pyzmq cannot be imported.
        """
        Handler.__init__(self, level, filter, bubble)
        try:
            import zmq
        except ImportError:
            raise RuntimeError('The pyzmq library is required for '
                               'the ZeroMQHandler.')
        #: the zero mq context. Falls back to a fresh context so the
        #: documented ``context=None`` default does not crash below.
        self.context = context if context is not None else zmq.Context()
        #: the zero mq socket.
        self.socket = self.context.socket(zmq.PUSH)
        self.uri = uri
        if uri is not None:
            self.socket.connect(uri)

        self.fds = fds
        self.extra_fds = extra_fds

    def export_record(self, record):
        """
        Extract relevant fields from a log record, converting datetime
        fields to epoch values so they can be serialized.

        NOTE: this rewrites ``record.time`` and ``record.extra['algo_dt']``
        in place on the record it is given.

        :returns: dict with one entry per name in ``self.fds`` and
            ``self.extra_fds``; names missing from the record map to None.
        """
        from zipline.utils.date_utils import EPOCH

        #Needed to extract record info from dictionary.
        record.pull_information()

        #Logbook stores record times as datetime objects, which
        #can't be serialized as-is, so we need to convert to
        #unix epoch representation.
        if record.time:
            assert isinstance(record.time, datetime.datetime)
            #logbook measures time in utc already, no need to convert.
            time = record.time.replace(tzinfo=pytz.utc)
            record.time = EPOCH(time)

        #Do the same if algo_dt is a datetime object.
        if 'algo_dt' in record.extra:
            algo_dt = record.extra['algo_dt']
            if isinstance(algo_dt, datetime.datetime):
                algo_dt = EPOCH(algo_dt.replace(tzinfo=pytz.utc))
                record.extra['algo_dt'] = algo_dt

        data = {}

        #Extract all the fields we care about from LogRecord's internal
        #dictionary.
        attrs = record.__dict__
        for field in self.fds:
            data[field] = attrs[field] if field in attrs else None

        # Explicit membership test rather than indexing, so a missing
        # entry is exported as None.
        extra = record.extra
        for field in self.extra_fds:
            data[field] = extra[field] if field in extra else None

        return data

    def emit(self, record):
        """Extract relevant fields and send them as a LOG_FRAME over the zmq socket."""
        payload = self.export_record(record)
        self.socket.send(LOG_FRAME(payload))

    def close(self):
        """No-op: the socket is deliberately left open here."""
        #self.socket.close()
        pass