goodbye old components.

This commit is contained in:
fawce
2012-08-08 22:51:37 -04:00
parent 7e86652a75
commit 853c2ea61e
11 changed files with 5 additions and 688 deletions
+2 -2
View File
@@ -3,7 +3,7 @@ import pytz
from pprint import pformat as pf
from datetime import datetime, timedelta
from unittest2 import TestCase
from unittest2 import TestCase, skip
from collections import defaultdict
from zipline.gens.composites import date_sorted_sources, merged_transforms
@@ -218,7 +218,7 @@ class ComponentTestCase(TestCase):
comp_c.proc.join()
mon_proc.join()
@skip
def test_full(self):
monitor = create_monitor(allocator)
+2 -1
View File
@@ -6,7 +6,8 @@ from collections import defaultdict
import numpy as np
from zipline.core.devsimulator import AddressAllocator
from zipline.optimize.factory import create_predictable_zipline
# TODO: refactor the factory to use generators
# from zipline.optimize.factory import create_predictable_zipline
DEFAULT_TIMEOUT = 15 # seconds
EXTENDED_TIMEOUT = 90
-13
View File
@@ -1,13 +0,0 @@
from feed import Feed
from merge import Merge
from passthrough import PassthroughTransform
from datasource import DataSource
from tradesimulation import TradeSimulationClient
# Names exported by a star-import of this module.
# ``__all__`` has to contain the exported names as strings; listing the
# class objects themselves makes ``from <module> import *`` raise
# TypeError.
__all__ = [
    'Feed',
    'Merge',
    'PassthroughTransform',
    'DataSource',
    'TradeSimulationClient',
]
-144
View File
@@ -1,144 +0,0 @@
"""
Abstract base class for Feed and Merge.
Component
|
Aggregate
|
/ \
Feed Merge
"""
import logbook
import zipline.protocol as zp
from zipline.core.component import Component
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE
from zipline.transitions import WorkflowMeta
from zipline.utils.protocol_utils import Enum
log = logbook.Logger('Aggregate')
# =================
# State Transitions
# =================
# Build the state enumeration once, then unpack its members into
# module-level names so they can be referenced directly.
AGGREGATE_STATES = Enum('INIT', 'READY', 'DRAINING')
INIT, READY, DRAINING = AGGREGATE_STATES

# Transition name -> pair of states. Presumably (from_state, to_state),
# with -1 standing for the not-yet-started state -- TODO confirm against
# WorkflowMeta.
AGGREGATE_TRANSITIONS = {
    'do_start': (-1, INIT),
    'do_run': (INIT, READY),
    'do_drain': (READY, DRAINING),
}
# =========
# Component
# =========
class Aggregate(Component):
    """
    Shared base class for Merge and Feed.

    Messages are received on ``pull_socket`` and sent out on
    ``feed_socket``. In between they are buffered in ``sources``, a
    mapping of source id to a list of events.

    Subclasses supply the pieces this class calls but does not define:
    ``append``, ``next``, ``frame`` and ``unframe``.
    """

    abstract = True
    __metaclass__ = WorkflowMeta

    @property
    def get_type(self):
        """Component type tag; every aggregate is a CONDUIT."""
        return COMPONENT_TYPE.CONDUIT

    def add_source(self, source_id):
        """Give ``source_id`` a fresh, empty event buffer."""
        self.sources[source_id] = list()

    # -------------
    # Core Methods
    # -------------

    def do_work(self):
        """
        Handle at most one incoming message from ``pull_socket``.

        A DONE control message bumps the finished-source counter; once
        every known source is finished the buffer is drained and the
        component signals that it is done. Any other message is
        unframed, buffered, and -- if every source now has something
        buffered -- the next event is sent out.
        """
        # NOTE(review): the nesting below was reconstructed from a view
        # that lost indentation; confirm that the is_full() check
        # belongs to the non-DONE branch.
        if self.socks.get(self.pull_socket) != self.zmq.POLLIN:
            return

        message = self.pull_socket.recv()

        if message == str(CONTROL_PROTOCOL.DONE):
            self.ds_finished_counter += 1
            if len(self.sources) == self.ds_finished_counter:
                # Drain any remaining messages in the buffer
                log.debug("Draining Feed")
                self.state = DRAINING
                self.drain()
                self.signal_done()
            return

        self.append(self.unframe(message))

        if self.is_full():
            outgoing = self.next()
            if outgoing:
                self.send(outgoing)

    # -------------
    # Flow Control
    # -------------

    def drain(self):
        """
        Keep pulling events via ``next`` and sending them until the
        buffer holds no more messages, heartbeating on every pass.
        """
        while self.pending_messages():
            outgoing = self.next()
            self.heartbeat()
            if outgoing:
                self.send(outgoing)

    def send(self, event):
        """
        Frame ``event`` and push it to ``feed_socket`` without
        blocking, then update the per-source and total sent counters.
        """
        framed = self.frame(event)
        self.feed_socket.send(framed, self.zmq.NOBLOCK)
        self.sent_counters[event.source_id] += 1
        self.sent_count += 1

    def is_full(self):
        """
        Return True when every source has at least one buffered event
        (vacuously True when there are no sources).
        """
        return all(
            len(buffered) != 0 for buffered in self.sources.itervalues()
        )

    def pending_messages(self):
        """
        Return the number of buffered events summed over all sources.
        """
        return sum(len(buffered) for buffered in self.sources.itervalues())

    def __len__(self):
        """
        The aggregate's length is the number of sources it tracks, not
        the number of buffered events.
        """
        return len(self.sources)
-70
View File
@@ -1,70 +0,0 @@
"""
Commonly used messaging components.
"""
import zipline.protocol as zp
from zipline.core.component import Component
from zipline.protocol import COMPONENT_TYPE
class DataSource(Component):
    """
    Abstract base class for data sources.

    Subclass and implement send_all -- usually this means looping
    through all records in a store, converting each to a dict, and
    calling send(map).

    Every datasource keeps a ``filter`` dict::

        - key   -- name of the filter, e.g. sid
        - value -- a primitive representing the filter, e.g. a list
                   of ints

    Change the filters through set_filter(name, value).
    """

    def set_filter(self, name, value):
        """Store ``value`` as the filter called ``name``."""
        self.filter[name] = value

    def setup_source(self):
        """Start with no filters and no current event."""
        self.filter = dict()
        self.cur_event = None

    @property
    def get_id(self):
        """
        Identifier of this component. It is fixed at class level and
        must not depend on arguments to the init function. Examples:

            - "TradeDataSource"
            - "RandomEquityTrades"
            - "SpecificEquityTrades"

        Subclasses must override this.
        """
        raise NotImplementedError

    @property
    def get_type(self):
        """Component type tag; every data source is a SOURCE."""
        return COMPONENT_TYPE.SOURCE

    def open(self):
        """Connect the outgoing data socket."""
        self.data_socket = self.connect_data()

    def send(self, event):
        """
        Stamp ``event`` with this source's id and type, frame it and
        emit it on the data socket.

        If framing raises INVALID_DATASOURCE_FRAME the exception is
        handed to ``signal_exception`` and its result is returned
        instead of sending anything.
        """
        assert isinstance(event, zp.ndict)

        event['source_id'] = self.get_id
        event['type'] = self.get_type

        try:
            framed = self.frame(event)
        except zp.INVALID_DATASOURCE_FRAME as exc:
            return self.signal_exception(exc)
        else:
            self.data_socket.send(framed)

    def frame(self, event):
        """Serialize ``event`` with the datasource framing."""
        return zp.DATASOURCE_FRAME(event)

    def do_work(self):
        """Subclasses must implement the actual emitting loop."""
        raise NotImplementedError()
-111
View File
@@ -1,111 +0,0 @@
import logbook
from collections import defaultdict, Counter
from zipline.components.aggregator import Aggregate, \
AGGREGATE_STATES, AGGREGATE_TRANSITIONS
import zipline.protocol as zp
log = logbook.Logger('Feed')
# =========
# Component
# =========
class Feed(Aggregate):
    """
    Connects to N PULL sockets, publishing all messages received to a
    PUB socket. Published messages are guaranteed to be in chronological
    order based on message property dt. Expects to be instantiated in
    one execution context (thread, process, etc) and run in another.
    """

    states = list(AGGREGATE_STATES)
    transitions = AGGREGATE_TRANSITIONS
    initial_state = -1

    def init(self):
        """Reset counters and buffers and enter the INIT state."""
        self.sent_count = 0
        self.received_count = 0
        self.ds_finished_counter = 0

        # source_id -> list of buffered events, oldest first
        self.sources = defaultdict(list)

        # source_id -> integer count
        self.sent_counters = Counter()
        self.recv_counters = Counter()

        self.state = AGGREGATE_STATES.INIT

    @property
    def get_id(self):
        """Fixed component id."""
        return "FEED"

    @property
    def draining(self):
        """True once the component has entered the DRAINING state."""
        return self.state == AGGREGATE_STATES.DRAINING

    # -------
    # Sockets
    # -------

    def open(self):
        """Bind the inbound data socket and the outbound feed socket."""
        self.pull_socket = self.bind_data()
        self.feed_socket = self.bind_feed()

    # -------
    # Framing
    # -------

    def unframe(self, msg):
        """Decode a message produced by a datasource."""
        return zp.DATASOURCE_UNFRAME(msg)

    def frame(self, event):
        """Encode an event for the feed socket."""
        return zp.FEED_FRAME(event)

    # -------------
    # Flow Control
    # -------------

    def append(self, event):
        """
        Add an event to the buffer for the source specified by
        source_id, and update the receive counters.
        """
        self.sources[event.source_id].append(event)
        self.recv_counters[event.source_id] += 1
        self.received_count += 1

    def next(self):
        """
        Pop and return the buffered event with the earliest ``dt``.

        Returns None when the feed is neither full nor draining, and
        False when no real (non-filler) event is buffered. Callers
        only test the result for truthiness.
        """
        # TODO: this is redundant to the guard in aggregator.
        # is_full and draining defined in aggregator
        if not (self.is_full() or self.draining):
            return

        earliest_source = None
        earliest_event = None

        # iterate over the queues of events from all sources
        # (1 queue per datasource)
        for source in self.sources.itervalues():
            # Filler events (dt is None) carry no data; discard every
            # one sitting at the head so that the first real event of
            # this source takes part in the comparison. Dropping a
            # single filler and moving on would hide an event queued
            # behind it for this round and could let a later event
            # from another source be emitted first.
            while source and source[0].dt is None:
                source.pop(0)

            if not source:
                continue

            head = source[0]
            # "<=" means that on equal timestamps the source visited
            # last wins.
            if earliest_event is None or head.dt <= earliest_event.dt:
                earliest_event = head
                earliest_source = source

        if earliest_event is not None:
            return earliest_source.pop(0)
        return False
-83
View File
@@ -1,83 +0,0 @@
import zipline.protocol as zp
from zipline.components.aggregator import Aggregate, \
AGGREGATE_STATES, AGGREGATE_TRANSITIONS
from collections import defaultdict, Counter
class Merge(Aggregate):
    """
    Merges multiple streams of events into single messages.
    """

    states = list(AGGREGATE_STATES)
    transitions = AGGREGATE_TRANSITIONS
    initial_state = -1

    def init(self):
        """Reset counters, buffers and the manual draining flag."""
        self.sent_count = 0
        self.received_count = 0
        # Manual override; the state-based check in _is_draining is
        # what normally reports draining.
        self.draining = False
        self.ds_finished_counter = 0

        # transform name -> list of buffered events, oldest first
        self.sources = defaultdict(list)

        # source_id -> integer count
        self.sent_counters = Counter()
        self.recv_counters = Counter()

    @property
    def get_id(self):
        """Fixed component id."""
        return "MERGE"

    # -------
    # Sockets
    # -------

    def open(self):
        """Bind the inbound merge socket and the outbound result socket."""
        self.pull_socket = self.bind_merge()
        self.feed_socket = self.bind_result()

    # -------
    # Framing
    # -------

    def unframe(self, msg):
        """Decode a message produced by a transform."""
        return zp.TRANSFORM_UNFRAME(msg)

    def frame(self, event):
        """Encode a merged event for the result socket."""
        return zp.MERGE_FRAME(event)

    # ---------
    # Data Flow
    # ---------

    def append(self, event):
        """
        :param event: a ndict with one entry. key is the name of the
        transform, value is the transformed value.

        Add an event to the buffer for the source named by that key
        and update the receive counters.
        """
        source_id = event.keys()[0]
        self.sources[source_id].append(event)
        # Keep the per-source counter in step with received_count; it
        # was created in init but never incremented before.
        self.recv_counters[source_id] += 1
        self.received_count += 1

    def _is_draining(self):
        """
        Return True when the component is draining.

        Aggregate.do_work announces draining by setting ``self.state``
        to DRAINING and never touches the ``draining`` attribute, so
        the flag alone stays False for the whole drain. Honour both.
        """
        if self.draining:
            return True
        return getattr(self, 'state', None) == AGGREGATE_STATES.DRAINING

    def next(self):
        """
        Get the next merged message from the buffer.

        Returns None when the buffer is neither full nor draining, or
        when nothing at all is buffered. Otherwise pops the oldest
        passthrough event and merges the oldest buffered event of
        every other source into it.
        """
        if not (self.is_full() or self._is_draining()):
            return

        if self.pending_messages() == 0:
            return

        # get the raw event from the passthrough transform.
        # NOTE(review): if transform output is buffered while the
        # passthrough queue is empty, pop(0) raises IndexError --
        # confirm whether that can happen during a drain.
        passthrough = self.sources[zp.TRANSFORM_TYPE.PASSTHROUGH]
        result = passthrough.pop(0).PASSTHROUGH

        for source, events in self.sources.iteritems():
            if source == zp.TRANSFORM_TYPE.PASSTHROUGH:
                continue
            if len(events) > 0:
                cur = events.pop(0)
                result.merge(cur)

        return result
-18
View File
@@ -1,18 +0,0 @@
from zipline.transforms import BaseTransform
from zipline.protocol import FEED_FRAME, TRANSFORM_TYPE
class PassthroughTransform(BaseTransform):
    """
    A bypass transform: the incoming event is not altered, it is only
    re-framed with FEED_FRAME and returned under the PASSTHROUGH name.
    """

    def init(self):
        """Register this transform under the name 'PASSTHROUGH'."""
        self.props = dict(name='PASSTHROUGH')

    #TODO, could save some cycles by skipping the _UNFRAME call
    # and just setting value to original msg string.
    def transform(self, event):
        """Return a name/value dict wrapping the re-framed event."""
        return dict(
            name=TRANSFORM_TYPE.PASSTHROUGH,
            value=FEED_FRAME(event),
        )
-243
View File
@@ -1,243 +0,0 @@
import logbook
import datetime
import zmq
import zipline.protocol as zp
import zipline.finance.performance as perf
from zipline.core.component import Component
from zipline.finance.trading import TransactionSimulator
from zipline.utils.protocol_utils import ndict
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
from logbook import Logger, NestedSetup, Processor
log = logbook.Logger('TradeSimulation')
class TradeSimulationClient(Component):
    """
    Component that consumes merged events from the result feed,
    simulates transactions for open orders, updates the performance
    tracker and hands batched event data to the user algorithm.
    """

    def init(self, trading_environment, sim_style, results_socket, algorithm):
        """
        Set up counters, the simulated clock, the transaction
        simulator and the performance tracker.

        :param trading_environment: provides ``period_start`` and is
            passed on to the performance tracker.
        :param sim_style: passed as ``style`` to TransactionSimulator.
        :param results_socket: address to connect a PUSH socket to in
            ``open``; falsy disables result/log publishing.
        :param algorithm: the algorithm to drive; must provide
            set_order, set_logger, get_sid_filter, initialize and
            handle_data.
        """
        self.received_count = 0
        self.prev_dt = None
        self.event_queue = None
        self.txn_count = 0
        self.order_count = 0
        self.trading_environment = trading_environment
        self.current_dt = trading_environment.period_start
        self.last_iteration_dur = datetime.timedelta(seconds=0)
        self.algorithm = algorithm
        # Orders placed by the algorithm are routed to self.order.
        self.algorithm.set_order(self.order)

        self.max_wait = datetime.timedelta(seconds=60)
        self.last_msg_dt = datetime.datetime.utcnow()

        self.txn_sim = TransactionSimulator(
            open_orders={},
            style=sim_style
        )

        # sid -> most recent event; persists across get_data calls.
        self.event_data = ndict()
        self.perf = perf.PerformanceTracker(
            self.trading_environment,
            self.algorithm.get_sid_filter()
        )
        self.zmq_out = None
        self.results_socket = results_socket
        self.algo_initialized = False

    @property
    def get_id(self):
        """Fixed component id taken from the finance protocol."""
        return str(zp.FINANCE_COMPONENT.TRADING_CLIENT)

    def open(self):
        """
        Connect to the result feed and, when a results address was
        given, open a PUSH socket to it for logs and performance.
        """
        self.result_feed = self.connect_result()

        if self.results_socket:
            sock = self.context.socket(zmq.PUSH)
            sock.connect(self.results_socket)
            # From here on results_socket holds the socket object,
            # not the address it was created from.
            self.results_socket = sock
            self.sockets.append(sock)
            self.out_socket = sock
            self.setup_logging(sock)
            self.perf.publish_to(sock)

    def initialize_algo(self):
        """ Setup loggers for algorithm and run algorithm's own
        initialize method.
        """
        self.logger = Logger("Print")
        self.algo_log = Logger("AlgoLog")
        self.algorithm.set_logger(self.algo_log)
        self.do_op(self.algorithm.initialize)
        self.algo_initialized = True

    def setup_logging(self, socket = None):
        """
        Create the ZeroMQ log handler on ``socket`` (or on
        ``results_socket`` when none is given) and remember the stdout
        capture class.
        """
        sock = socket or self.results_socket
        self.zmq_out = ZeroMQLogHandler(
            socket = sock,
        )

        # This is a class, which is instantiated later
        # in run_algorithm. The class provides a generator.
        self.stdout_capture = stdout_only_pipe

    def do_work(self):
        """
        Initialize the algorithm on first use, then handle at most one
        message from the result feed: DONE ends the simulation, any
        other message is unframed and processed.
        """
        # NOTE(review): the nesting below was reconstructed from a view
        # that lost indentation; confirm that the max-loss check
        # belongs inside the POLLIN branch.
        if not self.algo_initialized:
            self.initialize_algo()

        # see if the poller has results for the result_feed
        if self.socks.get(self.result_feed) == self.zmq.POLLIN:

            self.last_msg_dt = datetime.datetime.utcnow()

            # get the next message from the result feed
            msg = self.result_feed.recv()

            # if the feed is done, shut 'er down
            if msg == str(zp.CONTROL_PROTOCOL.DONE):
                self.finish_simulation()
                return

            # result_feed is a merge component, so unframe accordingly
            event = zp.MERGE_UNFRAME(msg)
            self.received_count += 1
            # update performance and relay the event to the algorithm
            self.process_event(event)
            if self.perf.exceeded_max_loss:
                self.finish_simulation()

    def finish_simulation(self):
        """Close out performance tracking and signal completion."""
        log.info("TradeSimulation is Done")
        # signal the performance tracker that the simulation has
        # ended. Perf will internally calculate the full risk report.
        self.perf.handle_simulation_end()

        # signal Simulator, our ComponentHost, that this component is
        # done and Simulator needn't block exit on this component.
        self.signal_done()

    def process_event(self, event):
        """
        Apply ``event`` to open orders, update performance, queue the
        event and run the algorithm if the event is not older than the
        simulated clock.
        """
        # generate transactions, if applicable
        txn = self.txn_sim.apply_trade_to_open_orders(event)
        if txn:
            event.TRANSACTION = txn
            # track the number of transactions, for testing purposes.
            self.txn_count += 1
        else:
            event.TRANSACTION = None

        # the performance class needs to process each event, without
        # skipping. Algorithm should wait until the performance has been
        # updated, so that down stream components can safely assume that
        # performance is up to date. Note that this is done before we
        # mark the time for the algorithm's processing, thereby not
        # running the algo's clock for performance book keeping.
        self.perf.process_event(event)

        # mark the start time for client's processing of this event.
        event_start = datetime.datetime.utcnow()

        # queue the event.
        self.queue_event(event)

        # if the event is later than our current time, run the algo
        # otherwise, the algorithm has fallen behind the feed
        # and processing per event is longer than time between events.
        if event.dt >= self.current_dt:
            # compress time by moving the current_time up to the event
            # time.
            self.current_dt = event.dt
            self.run_algorithm()

        # NOTE(review): reconstructed nesting -- the two statements
        # below are assumed to run for every event, not only when the
        # algorithm ran; confirm.
        # tally the time spent on this iteration
        self.last_iteration_dur = datetime.datetime.utcnow() - event_start
        # move the algorithm's clock forward to include iteration time
        self.current_dt = self.current_dt + self.last_iteration_dur

    def run_algorithm(self):
        """
        As per the algorithm protocol:

        - Set the current portfolio for the algorithm as per protocol.
        - Construct data based on backlog of events, send to algorithm.
        """
        data = self.get_data()
        if len(data) > 0:
            data.portfolio = self.perf.get_portfolio()
            # data injection pipeline for log rerouting
            # any fields injected here should be added to
            # LOG_EXTRA_FIELDS in zipline/protocol.py
            self.do_op(self.algorithm.handle_data, data)

    def do_op(self, callable_op, *args, **kwargs):
        """ Wrap a callable operation with the zmq logbook
        handler if it exists."""
        if self.zmq_out:
            def inject_event_data(record):
                # Record the simulation time.
                record.extra['algo_dt'] = self.current_dt
            data_injector = Processor(inject_event_data)
            log_pipeline = NestedSetup([self.zmq_out,data_injector])
            with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''):
                callable_op(*args, **kwargs)

        # if no log socket, just run the algo normally
        else:
            callable_op(*args, **kwargs)

    #Testing utility for log capture.
    # TODO: remove test code from here.
    def test_run_algorithm(self):
        """
        Run ``handle_data('data')`` once inside the log pipeline, with
        the current wall-clock time injected as ``algo_dt``.
        """
        # since open is never called from some tests we need to
        # set the logger explicitly
        self.algorithm.set_logger(self.algo_log)

        def inject_event_data(record):
            # Mock an event.dt
            record.extra['algo_dt'] = datetime.datetime.utcnow()

        data_injector = Processor(inject_event_data)
        log_pipeline = NestedSetup([self.zmq_out,
                                    #e.g. FileHandler(...)
                                    data_injector])
        with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''):
            self.algorithm.handle_data('data')

    #def connect_order(self):
    #    return self.connect_push_socket(self.addresses['order_address'])

    def order(self, sid, amount):
        """
        Record an order for ``amount`` of ``sid`` stamped with the
        current simulated time, log it with the performance tracker
        and add it to the simulator's open orders.
        """
        order = zp.ndict({
            'dt':self.current_dt,
            'sid':sid,
            'amount':amount
        })
        self.order_count += 1
        self.perf.log_order(order)
        self.txn_sim.add_open_order(order)

    def queue_event(self, event):
        """Append ``event`` to the queue, creating the queue lazily."""
        # Identity check: "== None" would dispatch to the queue
        # object's own equality instead of testing for None.
        if self.event_queue is None:
            self.event_queue = []
        self.event_queue.append(event)

    def get_data(self):
        """
        Fold all queued events into ``event_data`` keyed by sid (a
        later event for the same sid replaces the earlier one), empty
        the queue and return ``event_data``.
        """
        # event_queue is None until the first queue_event call; treat
        # that as an empty queue instead of failing on iteration.
        for event in self.event_queue or ():
            #alias the dt as datetime
            event.datetime = event.dt
            self.event_data[event['sid']] = event
        self.event_queue = []
        return self.event_data
-2
View File
@@ -1,9 +1,7 @@
from host import ComponentHost
from component import Component
from monitor import Monitor
# Names exported by a star-import of this module.
# ``__all__`` has to contain the exported names as strings; listing the
# class objects themselves makes ``from <module> import *`` raise
# TypeError.
__all__ = [
    'Component',
    'Monitor',
    'ComponentHost',
]
+1 -1
View File
@@ -7,7 +7,7 @@ import logbook
from setproctitle import setproctitle
from signal import SIGHUP, SIGINT
from collections import OrderedDict, Counter
from collections import Counter
from zipline.protocol import (
CONTROL_PROTOCOL,