Merge pull request #76 from quantopian/union

Union
This commit is contained in:
Stephen Diehl
2012-07-06 13:35:41 -07:00
21 changed files with 727 additions and 350 deletions
+14 -7
View File
@@ -6,6 +6,7 @@ import pytz
from unittest2 import TestCase
from datetime import datetime, timedelta
from collections import defaultdict
from logbook.compat import LoggingHandler
from nose.tools import timed
@@ -25,16 +26,22 @@ EXTENDED_TIMEOUT = 90
allocator = AddressAllocator(1000)
class FinanceTestCase(TestCase):
leased_sockets = defaultdict(list)
def setUp(self):
#qutil.configure_logging()
self.zipline_test_config = {
'allocator':allocator,
'sid':133
'allocator' : allocator,
'sid' : 133,
'devel' : True
}
self.log_handler = LoggingHandler()
self.log_handler.push_application()
def tearDown(self):
self.log_handler.pop_application()
@timed(DEFAULT_TIMEOUT)
def test_factory_daily(self):
@@ -106,7 +113,7 @@ class FinanceTestCase(TestCase):
# non blocking. HUNCH: The trades are streaming through before the orders
# are placed.
@timed(EXTENDED_TIMEOUT)
#@timed(EXTENDED_TIMEOUT)
def test_orders(self):
# Simulation
@@ -135,7 +142,7 @@ class FinanceTestCase(TestCase):
)
@timed(DEFAULT_TIMEOUT)
#@timed(DEFAULT_TIMEOUT)
def test_aggressive_buying(self):
# Simulation
@@ -232,7 +239,7 @@ class FinanceTestCase(TestCase):
)
self.assertEqual(
zipline.sources['flat'].count,
zipline.sources['SpecificEquityTrades'].count,
self.zipline_test_config['trade_count'],
"The simulated trade source should send all trades."
)
@@ -243,7 +250,7 @@ class FinanceTestCase(TestCase):
"The algorithm should receive all trades."
)
@timed(DEFAULT_TIMEOUT)
#@timed(DEFAULT_TIMEOUT)
def test_sid_filter(self):
"""Ensure the algorithm's filter prevents events from arriving."""
# create a test algorithm whose filter will not match any of the
+43
View File
@@ -0,0 +1,43 @@
import gevent
from logbook.compat import LoggingHandler
from unittest2 import TestCase, skip
from zipline.core.monitor import Controller
class TestMonitor(TestCase):
def setUp(self):
self.log_handler = LoggingHandler()
self.log_handler.push_application()
def tearDown(self):
self.log_handler.pop_application()
def test_init(self):
pub_socket = 'tcp://127.0.0.1:5000'
route_socket = 'tcp://127.0.0.1:5001'
con = Controller(pub_socket, route_socket)
con.manage([])
def test_init_topology(self):
pub_socket = 'tcp://127.0.0.1:5000'
route_socket = 'tcp://127.0.0.1:5001'
con = Controller(pub_socket, route_socket, )
con.manage([ 'a', 'b', 'c', 'd' ])
@skip
def test_poll(self):
from mock_zmq import zmq_synthetic
pub_socket = 'tcp://127.0.0.1:5000'
route_socket = 'tcp://127.0.0.1:5001'
cancel_socket = 'tcp://127.0.0.1:5002'
con = Controller(pub_socket, route_socket, cancel_socket)
con.manage([ 'a', 'b', 'c', 'd' ])
con.zmq = zmq_synthetic
con.zmq_flavor = 'green'
con.period = 0.00001
gevent.spawn(con.run).join(timeout=con.period)
+10 -4
View File
@@ -16,11 +16,13 @@ class PerformanceTestCase(unittest.TestCase):
self.benchmark_returns, self.treasury_curves = \
factory.load_market_data()
random_index = random.randint(
0,
len(self.treasury_curves)
)
for n in range(100):
random_index = random.randint(
0,
len(self.treasury_curves)
)
self.dt = self.treasury_curves.keys()[random_index]
self.end_dt = self.dt + datetime.timedelta(days=365)
@@ -29,6 +31,10 @@ class PerformanceTestCase(unittest.TestCase):
if self.end_dt <= now:
break
assert self.end_dt <= now, """
failed to find a date suitable daterange after 100 attempts. please double
check treasury and benchmark data in findb, and re-run the test."""
self.trading_environment = TradingEnvironment(
self.benchmark_returns,
self.treasury_curves,
+16 -2
View File
@@ -2,6 +2,8 @@ from datetime import timedelta
from collections import defaultdict
from unittest2 import TestCase
from logbook.compat import LoggingHandler
import zipline.utils.factory as factory
from zipline.finance.vwap import DailyVWAP, VWAPTransform
from zipline.finance.returns import ReturnsFromPriorClose
@@ -19,9 +21,15 @@ class ZiplineWithTransformsTestCase(TestCase):
allocator.lease(100)
self.trading_environment = factory.create_trading_environment()
self.zipline_test_config = {
'allocator':allocator,
'sid':133
'allocator' : allocator,
'sid' : 133,
'devel' : True
}
self.log_handler = LoggingHandler()
self.log_handler.push_application()
def tearDown(self):
self.log_handler.pop_application()
def test_vwap_tnfm(self):
zipline = SimulatedTrading.create_test_zipline(
@@ -36,8 +44,14 @@ class ZiplineWithTransformsTestCase(TestCase):
self.assertFalse(zipline.sim.exception)
class FinanceTransformsTestCase(TestCase):
def setUp(self):
self.trading_environment = factory.create_trading_environment()
self.log_handler = LoggingHandler()
self.log_handler.push_application()
def tearDown(self):
self.log_handler.pop_application()
def test_vwap(self):
+43 -52
View File
@@ -12,12 +12,12 @@ Abstract base class for Feed and Merge.
import logbook
import zipline.protocol as zp
from zipline.core.component import Component
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
CONTROL_FRAME, CONTROL_UNFRAME
from zipline.utils.protocol_utils import Enum
from zipline.core.controlled import do_handle_control_events
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE
from zipline.transitions import WorkflowMeta
from zipline.utils.protocol_utils import Enum
log = logbook.Logger('Aggregate')
@@ -34,6 +34,10 @@ AGGREGATE_TRANSITIONS = dict(
do_drain = (READY , DRAINING) ,
)
# =========
# Component
# =========
class Aggregate(Component):
"""
Abstract superclass to Merge & Feed. Acts on two sockets
@@ -41,7 +45,7 @@ class Aggregate(Component):
- pull_socket
- feed_socket
Both use ``data_buffer`` for buffering.
Both use ``sources`` for buffering.
Feed and Merge define these differently.
"""
@@ -53,6 +57,9 @@ class Aggregate(Component):
def get_type(self):
return COMPONENT_TYPE.CONDUIT
def add_source(self, source_id):
self.sources[source_id] = []
# -------------
# Core Methods
# -------------
@@ -61,39 +68,26 @@ class Aggregate(Component):
# wait for synchronization reply from the host
socks = dict(self.poll.poll(self.heartbeat_timeout))
# TODO: Abstract this out, maybe on base component
if socks.get(self.control_in) == self.zmq.POLLIN:
msg = self.control_in.recv()
event, payload = CONTROL_UNFRAME(msg)
# -- Heartbeat --
if event == CONTROL_PROTOCOL.HEARTBEAT:
# Heart outgoing
heartbeat_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.OK,
payload
)
self.control_out.send(heartbeat_frame)
# -- Soft Kill --
elif event == CONTROL_PROTOCOL.SHUTDOWN:
self.signal_done()
self.shutdown()
# -- Hard Kill --
elif event == CONTROL_PROTOCOL.KILL:
self.kill()
# ----------------
# Control Dispatch
# ----------------
#do_handle_control_events(self, socks)
# -------------
# Work Dispatch
# -------------
if socks.get(self.pull_socket) == self.zmq.POLLIN:
message = self.pull_socket.recv()
if message == str(CONTROL_PROTOCOL.DONE):
self.ds_finished_counter += 1
if len(self.data_buffer) == self.ds_finished_counter:
#drain any remaining messages in the buffer
if len(self.sources) == self.ds_finished_counter:
# Drain any remaining messages in the buffer
log.debug("Draining Feed")
self.state = DRAINING
self.drain()
self.signal_done()
else:
@@ -105,7 +99,15 @@ class Aggregate(Component):
try:
self.append(event)
self.send_next()
if not (self.is_full() or self.draining):
event = self.next()
if event:
self.send(event)
else:
pass
except zp.INVALID_DATASOURCE_FRAME as exc:
# Invalid message
return self.signal_exception(exc)
@@ -118,30 +120,25 @@ class Aggregate(Component):
"""
Send all messages in the buffer.
"""
self.state = DRAINING
while self.pending_messages() > 0:
self.send_next()
event = self.next()
if event:
self.send(event)
def send_next(self):
def send(self, event):
"""
Send the (chronologically) next message in the buffer.
"""
if not (self.is_full() or self.draining):
return
event = self.next()
if event:
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
self.sent_counters[event.source_id] += 1
self.sent_count += 1
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
self.sent_counters[event.source_id] += 1
self.sent_count += 1
def is_full(self):
"""
Indicates whether the buffer has messages in buffer for all
un-DONE, blocking sources.
"""
for source_id, events in self.data_buffer.iteritems():
for source_id, events in self.sources.iteritems():
if len(events) == 0:
return False
return True
@@ -152,19 +149,13 @@ class Aggregate(Component):
buffer.
"""
total = 0
for events in self.data_buffer.itervalues():
for events in self.sources.itervalues():
total += len(events)
return total
def add_source(self, source_id):
"""
Add a data source to the buffer.
"""
self.data_buffer[source_id] = []
def __len__(self):
"""
Buffer's length is same as internal map holding separate
sorted arrays of events keyed by source id.
"""
return len(self.data_buffer)
return len(self.sources)
-4
View File
@@ -2,14 +2,10 @@
Commonly used messaging components.
"""
import logging
import zipline.protocol as zp
from zipline.core.component import Component
from zipline.protocol import COMPONENT_TYPE
LOGGER = logging.getLogger('ZiplineLogger')
class DataSource(Component):
"""
Abstract baseclass for data sources. Subclass and implement send_all
+20 -17
View File
@@ -1,5 +1,5 @@
import logbook
from collections import Counter
from collections import defaultdict, Counter
from zipline.components.aggregator import Aggregate, \
AGGREGATE_STATES, AGGREGATE_TRANSITIONS
@@ -28,7 +28,7 @@ class Feed(Aggregate):
self.received_count = 0
self.ds_finished_counter = 0
self.data_buffer = {}
self.sources = defaultdict(list)
# source_id -> integer count
self.sent_counters = Counter()
@@ -71,7 +71,7 @@ class Feed(Aggregate):
Add an event to the buffer for the source specified by
source_id.
"""
self.data_buffer[event.source_id].append(event)
self.sources[event.source_id].append(event)
self.recv_counters[event.source_id] += 1
self.received_count += 1
@@ -84,24 +84,27 @@ class Feed(Aggregate):
if not(self.is_full() or self.draining):
return
cur_source = None
earliest_source = None
earliest_event = None
#iterate over the queues of events from all sources
#(1 queue per datasource)
for events in self.data_buffer.itervalues():
if len(events) == 0:
continue
cur_source = events
first_in_list = events[0]
if first_in_list.dt == None:
#this is a filler event, discard
events.pop(0)
# iterate over the queues of source from all sources
# (1 queue per datasource)
for source in self.sources.itervalues():
if len(source) == 0:
continue
if (earliest_event == None) or (first_in_list.dt <= earliest_event.dt):
earliest_event = first_in_list
earliest_source = cur_source
head = source[0]
if head.dt == None:
#this is a filler event, discard
source.pop(0)
continue
if (earliest_event == None) or (head.dt <= earliest_event.dt):
earliest_event = head
earliest_source = source
if earliest_event != None:
return earliest_source.pop(0)
else:
return False
+7 -7
View File
@@ -2,7 +2,7 @@ import zipline.protocol as zp
from zipline.components.aggregator import Aggregate, \
AGGREGATE_STATES, AGGREGATE_TRANSITIONS
from collections import Counter
from collections import defaultdict, Counter
class Merge(Aggregate):
"""
@@ -19,9 +19,7 @@ class Merge(Aggregate):
self.draining = False
self.ds_finished_counter = 0
# Depending on the size of this, might want to use a data
# structure with better asymptotics.
self.data_buffer = {}
self.sources = defaultdict(list)
# source_id -> integer count
self.sent_counters = Counter()
@@ -61,7 +59,7 @@ class Merge(Aggregate):
source_id.
"""
self.data_buffer[event.keys()[0]].append(event)
self.sources[event.keys()[0]].append(event)
self.received_count += 1
def next(self):
@@ -73,8 +71,10 @@ class Merge(Aggregate):
return
#get the raw event from the passthrough transform.
result = self.data_buffer[zp.TRANSFORM_TYPE.PASSTHROUGH].pop(0).PASSTHROUGH
for source, events in self.data_buffer.iteritems():
passthrough = self.sources[zp.TRANSFORM_TYPE.PASSTHROUGH]
result = passthrough.pop(0).PASSTHROUGH
for source, events in self.sources.iteritems():
if source == zp.TRANSFORM_TYPE.PASSTHROUGH:
continue
if len(events) > 0:
+117 -58
View File
@@ -19,6 +19,7 @@ import gevent_zeromq
# zmq_ctypes
#import zmq_ctypes
from zipline.protocol import CONTROL_UNFRAME
from zipline.utils.gpoll import _Poller as GeventPoller
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_STATE, \
COMPONENT_FAILURE, CONTROL_FRAME
@@ -28,7 +29,7 @@ log = logbook.Logger('Component')
from zipline.exceptions import ComponentNoInit
from zipline.transitions import WorkflowMeta
# LOGBOOK - embed PID in log output
log = logbook.Logger('Base')
class Component(object):
@@ -39,11 +40,6 @@ class Component(object):
:param addresses: a dict of name_string -> zmq port address strings.
Must have the following entries
:param sync_address: socket address used for synchronizing the start of
all workers, heartbeating, and exit notification
will be used in REP/REQ sockets. Bind is always on
the REP side.
:param data_address: socket address used for data sources to stream
their records. Will be used in PUSH/PULL sockets
between data sources and a Feed. Bind will always
@@ -82,12 +78,15 @@ class Component(object):
self.zmq = None
self.context = None
self.addresses = None
self.waiting = None
self.out_socket = None
self.killed = False
self.controller = None
# timeout after a full minute
self.heartbeat_timeout = 60 *1000
# TODO: state_flag is deprecated, remove
# TODO: error_state is deprecated, remove
self.state_flag = COMPONENT_STATE.OK
self.error_state = COMPONENT_FAILURE.NOFAILURE
self.on_done = None
@@ -98,6 +97,7 @@ class Component(object):
self.stop_tic = None
self.note = None
self.confirmed = False
self.devel = False
# Humanhashes make this way easier to debug because they stick
# in your mind unlike a 32 byte string of random hex.
@@ -190,6 +190,12 @@ class Component(object):
raise Exception("Unknown ZeroMQ Flavor")
def _run(self):
"""
The main component loop. This is wrapped inside a
exception reporting context inside of run.
The core logic of the all components is run here.
"""
self.start_tic = time.time()
self.done = False # TODO: use state flag
@@ -200,11 +206,20 @@ class Component(object):
self.setup_poller()
self.open()
self.setup_sync()
self.setup_control()
self.signal_ready()
self.lock_ready()
self.wait_ready()
# -----------------------
# YOU SHALL NOT PASS!!!!!
# -----------------------
# ... until the controller signals GO
self.loop()
self.shutdown()
log.info("Shutdown %r" % self)
self.stop_tic = time.time()
@@ -246,20 +261,8 @@ class Component(object):
Loop to do work while we still have work to do.
"""
while self.working():
self.confirm()
self.do_work()
def confirm(self):
"""
Send a synchronization request to the host.
"""
if not self.confirmed:
# TODO: proper framing
self.sync_socket.send(self.get_id + ":RUN")
self.receive_sync_ack() # blocking
self.confirmed = True
def runtime(self):
if self.ready() and self.start_tic and self.stop_tic:
return self.stop_tic - self.start_tic
@@ -299,6 +302,81 @@ class Component(object):
# Internal Maintenance
# ----------------------
def lock_ready(self):
"""
Unlock the component, topology is now ready to run.
"""
self.waiting = True
def unlock_ready(self):
"""
Unlock the component, topology is still pending.
"""
self.waiting = False
def wait_ready(self):
# Implicit side-effect of unlocking the component iff
# the GO message is received from the monitor level.
# This then unlocks the barrier and proceeds to the
# do_work state.
# Poll on a subset of the control protocol while we exist
# in the locked quasimode. Respond to HEARTBEAT and GO
# messages.
while self.waiting:
socks = dict(self.poll.poll(self.heartbeat_timeout))
msg = self.control_in.recv()
event, payload = CONTROL_UNFRAME(msg)
# ====
# Go
# ====
# A distributed lock from the controller to ensure
# synchronized start.
if event == CONTROL_PROTOCOL.HEARTBEAT:
heartbeat_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.OK,
payload
)
self.control_out.send(heartbeat_frame)
log.info('Prestart Heartbeat ' + self.get_id)
elif event == CONTROL_PROTOCOL.GO:
# Side effectful call from the controller to unlock
# and begin doing work only when the entire topology
# of the system beings to come online
log.info('Unlocking ' + self.__class__.__name__)
self.unlock_ready()
def signal_ready(self):
log.info(self.__class__.__name__ + ' is ready')
if hasattr(self, 'control_out'):
frame = CONTROL_FRAME(
CONTROL_PROTOCOL.READY,
''
)
self.control_out.send(frame)
def signal_cancel(self):
self.done = True
# TODO: no hasattr hacks
#if not self.controller:
if hasattr(self, 'control_out'):
frame = CONTROL_FRAME(
CONTROL_PROTOCOL.SHUTDOWN,
None
)
self.control_out.send(frame)
# then proceeds to do shutdown(), and teardown_sockets()
# to complete the process
def signal_exception(self, exc=None, scope=None):
"""
This is *very* important error tracking handler.
@@ -320,7 +398,7 @@ class Component(object):
self._exception = exc
exc_type, exc_value, exc_traceback = sys.exc_info()
trace = '\n>>>'.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
sys.stdout.write(trace)
if hasattr(self, 'control_out'):
@@ -342,10 +420,6 @@ class Component(object):
if self.out_socket:
self.out_socket.send(str(CONTROL_PROTOCOL.DONE))
#notify host we're done
# TODO: proper framing
self.sync_socket.send(self.get_id + ":" + str(CONTROL_PROTOCOL.DONE))
#notify controller we're done
done_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.DONE,
@@ -353,7 +427,6 @@ class Component(object):
)
self.control_out.send(done_frame)
self.receive_sync_ack()
#notify internal work look that we're done
self.done = True # TODO: use state flag
@@ -373,19 +446,6 @@ class Component(object):
# ZeroMQ. Either zmq.Poller or gpoll.Poller .
self.poll = self.zmq_poller()
def receive_sync_ack(self):
"""
Wait for synchronization reply from the host.
DEPRECATED, left in for compatability for now.
"""
socks = dict(self.sync_poller.poll(self.heartbeat_timeout))
if self.sync_socket in socks and socks[self.sync_socket] == self.zmq.POLLIN:
message = self.sync_socket.recv()
#else:
#raise Exception("Sync ack timed out on response for {id}".format(id=self.get_id))
def bind_data(self):
return self.bind_pull_socket(self.addresses['data_address'])
@@ -405,10 +465,27 @@ class Component(object):
return self.connect_push_socket(self.addresses['merge_address'])
def bind_result(self):
return self.bind_pub_socket(self.addresses['result_address'])
return self.bind_push_socket(self.addresses['result_address'])
def connect_result(self):
return self.connect_sub_socket(self.addresses['result_address'])
return self.connect_pull_socket(self.addresses['result_address'])
def bind_push_socket(self, addr):
push_socket = self.context.socket(self.zmq.PUSH)
push_socket.bind(addr)
self.out_socket = push_socket
self.sockets.append(push_socket)
return push_socket
def connect_pull_socket(self, addr):
pull_socket = self.context.socket(self.zmq.PULL)
pull_socket.connect(addr)
self.sockets.append(pull_socket)
self.poll.register(pull_socket, self.zmq.POLLIN)
return pull_socket
def bind_pull_socket(self, addr):
pull_socket = self.context.socket(self.zmq.PULL)
@@ -470,24 +547,6 @@ class Component(object):
self.poll.register(self.control_in, self.zmq.POLLIN)
self.sockets.extend([self.control_in, self.control_out])
def setup_sync(self):
"""
Setup the sync socket and poller. ( Connect )
DEPRECATED, left in for compatability for now.
"""
#LOGGER.debug("Connecting sync client for {id}".format(id=self.get_id))
self.sync_socket = self.context.socket(self.zmq.REQ)
self.sync_socket.connect(self.addresses['sync_address'])
#self.sync_socket.setsockopt(self.zmq.LINGER,0)
self.sync_poller = self.zmq_poller()
self.sync_poller.register(self.sync_socket, self.zmq.POLLIN)
self.sockets.append(self.sync_socket)
# -----------
# FSM Actions
# -----------
+74
View File
@@ -0,0 +1,74 @@
"""
Poller logic for a component which is controlled by the monitor, this is
largely universal and thus we break it out into a seperate module and
splice it into the dispatch loops for each component instance.
Example usage::
def do_work():
socks = self.poll.poll()
# Handle control events
do_handle_control_events()
# Handle other events
if socks.get(socket) == zmq.POLLIN:
...
"""
import zmq
from zipline.core.component import Component
from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, CONTROL_UNFRAME
def do_handle_control_events(cls, poller):
assert isinstance(cls, Component)
assert cls.control_in, 'Component does not have a control_in socket'
# If we're in devel mode drop out because the controller
# isn't guaranteed to be around anymore
if cls.devel:
return
if poller.get(cls.control_in) == zmq.POLLIN:
msg = cls.control_in.recv()
event, payload = CONTROL_UNFRAME(msg)
# ===========
# Heartbeat
# ===========
# The controller will send out a single number packed in
# a CONTROL_FRAME with ``heartbeat`` event every
# (n)-seconds. The component then has n seconds to
# respond to it. If not then it will be considered as
# malfunctioning or maybe CPU bound.
if event == CONTROL_PROTOCOL.HEARTBEAT:
# Heart outgoing
heartbeat_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.OK,
payload
)
# Echo back the heartbeat identifier to tell the
# controller that this component is still alive and
# doing work
cls.control_out.send(heartbeat_frame)
# =========
# Soft Kill
# =========
# Try and clean up properly and send out any reports or
# data that are done during a clean shutdown. Inform the
# controller that we're done.
elif event == CONTROL_PROTOCOL.SHUTDOWN:
cls.signal_done()
cls.shutdown()
# =========
# Hard Kill
# =========
# Just exit.
elif event == CONTROL_PROTOCOL.KILL:
cls.kill()
+11
View File
@@ -3,9 +3,18 @@ Simulator hosts all the components necessary to execute a simulation.
See :py:method""
"""
import logbook
import threading
from zipline.core.simulatorref import SimulatorBase
log = logbook.Logger('Dev Simulator')
DEPRECATION_WARNING = """
WARNING WARNING WARNING
THE DEVSIMULATOR IS DEPRECATED, IT WILL NOT BEHAVE LIKE ANY OTHER
SYSTEM USED IN TESTS OR IN PRODUCTION
"""
class AddressAllocator(object):
"""
Produces a iterator of 10000 sockets to allocate as needed.
@@ -38,6 +47,8 @@ class Simulator(SimulatorBase):
self.subthreads = []
self.running = False
log.warn(DEPRECATION_WARNING)
@property
def get_id(self):
return 'Simple Simulator'
+35 -59
View File
@@ -1,26 +1,21 @@
import os
import sys
import logbook
import datetime
from component import Component
from zipline.transforms import BaseTransform
from zipline.components import Feed, Merge, PassthroughTransform, \
DataSource
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_STATE
log = logbook.Logger('Host')
log = logbook.Logger('Topology')
class ComponentHost(Component):
class ComponentHost(object):
"""
Components that can launch multiple sub-components, synchronize
their start, and then wait for all components to be finished.
"""
def init(self, addresses):
assert hasattr(self, 'zmq_flavor'), \
""" You must specify a flavor of ZeroMQ for all ComponentHost
subclasses. """
def __init__(self, addresses):
self.addresses = addresses
self.running = False
@@ -32,17 +27,30 @@ class ComponentHost(Component):
self._components = {}
# ----------------------
self.sync_register = {}
self.timeout = datetime.timedelta(seconds=60)
self.exception = None
self.feed = Feed()
self.merge = Merge()
self.passthrough = PassthroughTransform()
self.controller = None
#register the feed and the merge
self.register_components([self.feed, self.merge, self.passthrough])
def _run(self):
self.open()
def run(self, catch_exceptions=True):
"""
Run the host.
"""
log.info('===== PARENT PID: %s' % os.getppid())
self.open()
#self.shutdown()
def shutdown(self, ensure_clean=True):
raise NotImplementedError
def register_controller(self, controller):
"""
Add the given components to the registry. Establish
@@ -74,33 +82,26 @@ class ComponentHost(Component):
self._components[component.guid] = component
self.components[component.get_id] = component
self.sync_register[component.get_id] = datetime.datetime.utcnow()
if isinstance(component, DataSource):
self.feed.add_source(component.source_id)
self.feed.add_source(component.get_id)
if isinstance(component, BaseTransform):
self.merge.add_source(component.get_id)
def unregister_component(self, component_id):
del self.components[component_id]
del self.sync_register[component_id]
def setup_sync(self):
"""
Setup the sync socket and poller. ( Bind )
"""
#log.debug("Connecting sync server.")
self.sync_socket = self.context.socket(self.zmq.REP)
self.sync_socket.bind(self.addresses['sync_address'])
self.sync_poller = self.zmq_poller()
self.sync_poller.register(self.sync_socket, self.zmq.POLLIN)
self.sockets.append(self.sync_socket)
def open(self):
assert hasattr(self, 'zmq_flavor'), \
""" You must specify a flavor of ZeroMQ for all Topology
subclasses. """
log.info('== Roll Call ==')
log.info('Controller')
self.launch_controller()
for component in self.components.itervalues():
log.info(component)
@@ -109,8 +110,6 @@ class ComponentHost(Component):
for component in self.components.itervalues():
self.launch_component(component)
self.launch_controller()
def is_running(self):
"""
DEPRECATED, left in for compatability for now.
@@ -122,38 +121,15 @@ class ComponentHost(Component):
return True
def loop(self, lockstep=True):
while self.is_running():
# wait for synchronization request at start, and DONE at end.
# don't timeout.
socks = dict(self.sync_poller.poll())
if socks.get(self.sync_socket) == self.zmq.POLLIN:
msg = self.sync_socket.recv()
try:
parts = msg.split(':')
sync_id, status = parts
except ValueError as exc:
self.signal_exception(exc)
# TODO: other way around
if status == str(CONTROL_PROTOCOL.DONE):
#log.debug("Component claims done: {id}".format(id=sync_id))
self.unregister_component(sync_id)
self.state_flag = COMPONENT_STATE.DONE
else:
self.sync_register[sync_id] = datetime.datetime.utcnow()
#log.info("confirmed {id}".format(id=msg))
# send synchronization reply
self.sync_socket.send('ack', self.zmq.NOBLOCK)
def ready(self):
return True
# ------------------
# Simulation Control
# ------------------
# Overloaded by simulator
def launch_controller(self, controller):
raise NotImplementedError
+228 -106
View File
@@ -1,9 +1,12 @@
import os
import zmq
import sys
import time
import gevent
import itertools
import logbook
import gevent_zeromq
from signal import SIGHUP, SIGINT
from collections import OrderedDict
@@ -11,14 +14,17 @@ from zipline.utils.gpoll import _Poller as GeventPoller
from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \
from zipline.utils.protocol_utils import ndict
INIT, SOURCES_READY, RUNNING, TERMINATE = CONTROL_STATES
CONTROLLER_TRANSITIONS = frozenset([
(-1 , INIT),
(INIT , SOURCES_READY),
(SOURCES_READY , RUNNING),
(INIT , TERMINATE),
(SOURCES_READY , TERMINATE),
(INIT , TERMINATE), # pseudo failure mode
(SOURCES_READY , TERMINATE), # pseudo failure mode
(RUNNING , TERMINATE),
])
@@ -32,6 +38,18 @@ class UnknownChatter(Exception):
log = logbook.Logger('Controller')
# The scalars determining the timing of the monitor behavior for
# the system.
PARAMETERS = ndict(dict(
GENERATIONAL_PERIOD = 1,
ALLOWED_SKIPPED_HEARTBEATS = 3,
ALLOWED_INVALID_HEARTBEATS = 3,
PRESTART_HEARBEATS = 3,
SOURCES_START_HEARTBEATS = 3,
SYSTEM_TIMEOUT = 50,
))
class Controller(object):
"""
A N to M messaging system for inter component communication.
@@ -43,38 +61,25 @@ class Controller(object):
the individual components.
:func message_sender: .
Topology is the set of components we expect to show up.
States are the transitions the sytems go through. The
simplest is from RUNNING -> NOT RUNNING .
Usage::
controller = Controller(
'tcp://127.0.0.1:5000',
'tcp://127.0.0.1:5001',
)
# typically you'd want to run this async to your main
# program since it blocks indefinetely.
controller.manage(
[ TOPOLOGY ]
[ STATES ]
)
"""
debug = False
period = 1
# Turn on debug for verbose logging of the system.
debug = True
period = PARAMETERS.GENERATIONAL_PERIOD
def __init__(self, pub_socket, route_socket):
def __init__(self, pub_socket, route_socket, devel=True):
self.devel = devel
self.nosignals = False
self.context = None
self.zmq = None
self.zmq_poller = None
self.running = False
self.polling = False
self.alive = False
self.tracked = set()
self.finished = set()
self.responses = set()
self.ctime = 0
@@ -89,6 +94,8 @@ class Controller(object):
self.error_replay = OrderedDict()
log.warn("Running Controller in development mode, will ONLY synchronize start.")
def init_zmq(self, flavor):
assert self.zmq_flavor in ['thread', 'mp', 'green']
@@ -97,6 +104,8 @@ class Controller(object):
self.zmq = zmq
self.context = self.zmq.Context()
self.zmq_poller = self.zmq.Poller
log.warning("USING DEVELOPMENT MODE IN MP CONTEXT NOT RECOMMENDED")
return
if flavor == 'thread':
self.zmq = zmq
@@ -114,7 +123,7 @@ class Controller(object):
self.zmq_poller = self.zmq.Poller
return
def manage(self, topology, states=None, context=None):
def manage(self, topology):
"""
Give the controller a set set of components to manage and
a set of state transitions for the entire system.
@@ -128,31 +137,43 @@ class Controller(object):
else:
self.freeform = False
self.topology = frozenset(topology)
self.polling = True
self.state = CONTROL_STATES.INIT
self.alive = True
@property
def state(self):
#log.info('returned %s' % self._state)
return self._state
@state.setter
def state(self, new):
old, self._state = self._state, new
old = self._state
if (old, new) not in CONTROLLER_TRANSITIONS:
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
if (old, new) in CONTROLLER_TRANSITIONS:
self._state = new
log.info("State Transition : %s -> %s" % (old, self._state))
else:
log.info("State Transition : %s -> %s" %(old, new))
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
def run(self):
self.running = True
self.init_zmq(self.zmq_flavor)
self.state = CONTROL_STATES.INIT
# Interpreter SIDE EFFECT
# -----------------------
# The last breathe of the interpreter will assume that we've
# failed unless we specify otherwise.
if not self.devel:
sys.exitfunc = self.signal_interrupt
# We overload this if ( and only if ) the topology exits
# cleanly. This prevents failure modes where the monitor
# dies.
try:
return self._poll() # use a python loop
except KeyboardInterrupt:
log.debug('Shutdown event loop')
log.info('Shutdown event loop')
def log_status(self):
"""
@@ -172,6 +193,13 @@ class Controller(object):
# Publications
# -------------
def send_go(self):
go_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.GO,
''
)
self.pub.send(go_frame)
def send_heart(self):
if not self.running:
return
@@ -210,45 +238,55 @@ class Controller(object):
assert self.route_socket
assert self.pub_socket
#assert self.cancel_socket
assert self.topology,\
""""Must define topology to monitor, call setup_controller() on
your Zipline. """
# -- Publish --
# =============
self.pub = self.context.socket(self.zmq.PUB)
self.pub.bind(self.pub_socket)
# -- Cancel --
# =============
#assert isinstance(self.cancel_socket,basestring), self.cancel_socket
#self.cancel = self.context.socket(self.zmq.REP)
#self.cancel.connect(self.cancel_socket)
self.pub.setsockopt(zmq.LINGER, 0)
# -- Router --
# =============
self.router = self.context.socket(self.zmq.ROUTER)
self.router.bind(self.route_socket)
self.router.setsockopt(zmq.LINGER, 0)
poller = self.zmq.Poller()
poller.register(self.router, self.zmq.POLLIN)
#poller.register(self.cancel, self.zmq.POLLIN)
self.associated += [self.pub, self.router]# self.cancel]
self.associated += [self.pub, self.router]
# TODO: actually do this
self.state = CONTROL_STATES.SOURCES_READY
self.state = CONTROL_STATES.RUNNING
buffer = []
# ===================
# Heartbeat Iteration
# ===================
for i in itertools.count(0):
self.log_status()
# Reset the responses for this cycle
self.responses = set()
# broadcast the heartbeat packet
self.ctime = time.time()
self.send_heart()
while self.polling:
# Reset the responses for this cycle
# ==============
# Hearbeat Cycle
# ==============
# Wait the responses
while self.alive:
socks = dict(poller.poll(self.period))
tic = time.time()
@@ -270,28 +308,90 @@ class Controller(object):
log.error('Invalid frame', rawmessage)
pass
#if socks.get(self.cancel) == self.zmq.POLLIN:
# self.logging.info('[Controller] Received Cancellation')
# rawmessage = self.cancel.recv()
# self.cancel.send('')
# self.shutdown(soft=True)
# break
# ================
# Heartbeat Stats
# ================
self.beat()
complete = self.beat()
# ================
# Topology Status
# ================
# Is the entire topology told us its DONE
done = len(self.finished) == len(self.topology)
# Is the entire topology shown up to the party
complete = len(self.tracked) == len(self.topology)
if complete:
self.send_go()
# If we're running in development stop here
# because our responsibilites are over. The
# zipline will either run to completion or die,
# monitor doesn't care anymore because its all
# threads.
if self.devel:
log.warn("Shutting down Controller because in devel mode")
#sys.exitfunc = lambda: None
self.shutdown(soft=True)
log.info('Heartbeat (%s, %s)' % (done, complete))
# ================
# Exit Strategies
# ================
if self.zmq_flavor == 'green':
gevent.sleep(0)
if self.state is CONTROL_STATES.TERMINATE:
# Will also fall out of loop when done, if using
# non-freeform topology
if done:
log.info('Entire topology exited cleanly')
self.shutdown(soft=True)
# Noop exit func
#sys.exitfunc = lambda: None
# Send SIGHUP to buritto
self.signal_hangup()
if not self.alive:
break
if not self.polling:
break
def signal_hangup(self):
"""
A clean exit, inform the burrito ( and arbiter ) that
we're good. The topology exited cleanly and we can prove
it.
"""
if not self.nosignals:
ppid = os.getpid()
os.kill(ppid, SIGHUP)
else:
log.warning("Would SIGHUP here, but disabled")
# After loop exits
self.terminated = True
def signal_interrupt(self):
"""
Send a SIGINT in the error mode that the monitor's
interpreter exits. If the monitor dies the system is
considered a failure.
"""
if not self.nosignals:
ppid = os.getpid()
os.kill(ppid, SIGINT)
else:
log.warning("Would SIGINT here, but disabled")
def beat(self):
"""
The tracking logic of the system. It's the "stethoscope"
that inspects to the heartbeats in a generation and
infers the state of the system from the responses.
"""
# These the set overloaded operations
# A & B ~ set.intersection
@@ -303,24 +403,44 @@ class Controller(object):
# send us back a response.
# * new - Components we haven't heard from yet, but sent back the
# right response.
# * finished - Components we were tracking but have now
# finished, when this set goes to zero this
# triggers the end of the topology.
good = self.tracked & self.responses
bad = self.tracked - good
new = self.responses - good
missing = self.topology - self.tracked
for component in new:
self.new(component)
if self.debug:
log.info('New component %r' % component)
for component in bad:
self.fail(component)
if self.debug:
log.info('Bad component %r' % component)
if self.debug:
for component in missing:
log.info('Missing component %r' % component)
for component in self.tracked:
if component not in self.topology:
log.info('Uninvited component %r' % component)
# --------------
# Init Handlers
# --------------
def new_source(self):
if self.state is CONTROL_STATES.RUNNING:
self.state = SOURCES_READY
#if self.state is CONTROL_STATES.RUNNING:
#self.state = SOURCES_READY
pass
def new_universal(self):
pass
@@ -331,7 +451,11 @@ class Controller(object):
if self.state is CONTROL_STATES.TERMINATE:
return
log.info(' Now Tracking "%s" ' % component)
if component in self.finished:
#log.info("Got heartbeat from supposedly finished component")
return
log.info('Now Tracking "%s" ' % component)
universal = self.new_universal
init_handlers = {
@@ -342,7 +466,7 @@ class Controller(object):
init_handlers.get(component, universal)()
self.tracked.add(component)
else:
# Some sort of socket collision has occured, this is
# Some sort of socket collision has occurred, this is
# a very bad failure mode.
raise UnknownChatter(component)
@@ -353,8 +477,8 @@ class Controller(object):
def fail_universal(self):
pass
# TODO: this requires higher order functionality
#log.error('System in exception state, shutting down')
#self.shutdown(soft=True)
log.error('System in exception state, shutting down')
self.shutdown(soft=True)
def fail(self, component):
if self.state is CONTROL_STATES.TERMINATE:
@@ -364,16 +488,22 @@ class Controller(object):
fail_handlers = { }
if component in self.topology or self.freeform:
log.warning('Component "%s" timed out' % component)
log.warning('Component "%s" missed heartbeat' % component)
self.tracked.remove(component)
fail_handlers.get(component, universal)()
# TODO: determine when this propogates to a true
# failure, missing one heartbeat could just mean that
# its CPU overloaded
#fail_handlers.get(component, universal)()
# -------------------
# Completion Handling
# -------------------
def done(self, component):
log.info('Component "%s" signaled done.' % component)
self.finished.add(component)
self.tracked.discard(component)
log.info('Component "%s" finished.' % component)
# --------------
# Error Handling
@@ -393,6 +523,7 @@ class Controller(object):
if component in self.topology or self.freeform:
self.error_replay[(component, time.time())] = failure
log.error('Component in exception state: %s' % component)
log.error(str(failure))
exception_handlers.get(component, universal)()
else:
@@ -406,30 +537,48 @@ class Controller(object):
"""
Check for proper framing at the transport layer.
Seperates the proper frames from anything else that might
be coming over the wire. Which shouldn't happen ... right?
be coming over the wire.
"""
identity = msg[0]
identity = msg[0] # identity of the socket
id, status = CONTROL_UNFRAME(msg[1])
# A component is telling us its alive:
# I'm alive, condemned to be a free process in the cold
# cold dark absurd Zipline universe.
if id is CONTROL_PROTOCOL.READY:
self.responses.add(identity)
return
# The heartbeat love song between a component and the
# controller
if id is CONTROL_PROTOCOL.OK:
if status == str(self.ctime):
# Go to your bosom; knock there, and ask your heart what
# it doth know...
self.responses.add(identity)
elif status < self.ctime:
# False face must hide what the false heart doth know.
log.warning('Delayed heartbeat received.')
else:
# Otherwise its something weird and we don't know
# what to do so just say so, probably line noise
# from ZeroMQ
log.error("Weird stuff happened: %s" % msg)
# What's in a name? that which we call a rose...
log.error("Weird heartbeat packet happened: %s" % msg)
return
# A component is telling us it failed, and how
if id is CONTROL_PROTOCOL.EXCEPTION:
self.exception(identity, status)
return
# A component is telling us its done with work and won't
# be talking to us anymore
if id is CONTROL_PROTOCOL.DONE:
self.done(identity)
return
# -------------------
# Hooks for Endpoints
@@ -474,54 +623,27 @@ class Controller(object):
def do_error_replay(self):
for (component, time), error in self.error_replay.iteritems():
log.info('Error Log for -- %s --:\n%s' %
(component, error))
log.debug('Component Log for -- %s --:\n%s' % (component, error))
def shutdown(self, hard=False, soft=True, context=None):
def shutdown(self, hard=False, soft=True):
assert hard or soft, """ Must specify kill hard or soft """
if self.state is CONTROL_STATES.TERMINATE:
return
if not self.polling:
return
self.alive = False
self.polling = False
assert hard or soft, """ Must specify kill hard or soft """
if hard:
if hard and not self.devel:
self.state = CONTROL_STATES.TERMINATE
log.info('Hard Shutdown')
#for asoc in self.associated:
#asoc.close()
if soft:
if soft and not self.devel:
self.state = CONTROL_STATES.TERMINATE
log.info('Soft Shutdown')
self.send_softkill()
#for asoc in self.associated:
#asoc.close()
#self.do_error_replay()
self.do_error_replay()
if __name__ == '__main__':
print 'Running on '\
'tcp://127.0.0.1:5000 '\
'tcp://127.0.0.1:5001 '
controller = Controller(
'tcp://127.0.0.1:5000',
'tcp://127.0.0.1:5001',
)
controller.zmq_flavor = 'green'
controller.manage(
'freeform',
[]
)
controller.run()
#self.pub.close()
#self.router.close()
+3 -3
View File
@@ -36,14 +36,14 @@ Risk Report
"""
import logging
import logbook
import datetime
import math
import numpy as np
import numpy.linalg as la
from zipline.utils.date_utils import epoch_now
LOGGER = logging.getLogger('ZiplineLogger')
log = logbook.Logger('Risk')
def advance_by_months(dt, jump_in_months):
month = dt.month + jump_in_months
@@ -255,7 +255,7 @@ class RiskMetrics():
cur_return = math.log(1.0 + r) + cur_return
#this is a guard for a single day returning -100%
except ValueError:
LOGGER.warn("{cur} return, zeroing the returns".format(cur=cur_return))
log.warn("{cur} return, zeroing the returns".format(cur=cur_return))
cur_return = 0.0
compounded_returns.append(cur_return)
+8 -11
View File
@@ -21,13 +21,12 @@ import zipline.protocol as zp
class TradeDataSource(DataSource):
def init(self, source_id):
self.source_id = source_id
def init(self):
self.setup_source()
@property
def get_id(self):
return 'TradeDataSource'
#@property
#def get_id(self):
# return 'TradeDataSource'
def send(self, event):
"""
@@ -37,14 +36,14 @@ class TradeDataSource(DataSource):
:rtype: None
"""
event.source_id = self.source_id
event.source_id = self.get_id
if event.sid in self.filter['sid']:
message = zp.DATASOURCE_FRAME(event)
else:
blank = ndict({
"type" : zp.DATASOURCE_TYPE.TRADE,
"source_id" : self.source_id
"source_id" : self.get_id
})
message = zp.DATASOURCE_FRAME(blank)
@@ -56,8 +55,7 @@ class RandomEquityTrades(TradeDataSource):
Generates a random stream of trades for testing.
"""
def init(self, sid, source_id, count):
self.source_id = source_id
def init(self, sid, count):
self.count = count
self.incr = 0
self.sid = sid
@@ -95,7 +93,7 @@ class SpecificEquityTrades(TradeDataSource):
Generates a random stream of trades for testing.
"""
def init(self, source_id, event_list):
def init(self, event_list):
"""
:param event_list: should be a chronologically ordered list of
dictionaries in the following form::
@@ -107,7 +105,6 @@ class SpecificEquityTrades(TradeDataSource):
'volume' : integer for volume
}
"""
self.source_id = source_id
self.event_list = event_list
self.count = 0
+3 -3
View File
@@ -1,12 +1,12 @@
import pytz
import math
import logging
import logbook
import datetime
import zipline.protocol as zp
from zipline.protocol import SIMULATION_STYLE
LOGGER = logging.getLogger('ZiplineLogger')
log = logbook.Logger('Transaction Simulator')
class TransactionSimulator(object):
@@ -39,7 +39,7 @@ class TransactionSimulator(object):
log = "requested to trade zero shares of {sid}".format(
sid=event.sid
)
LOGGER.debug(log)
log.debug(log)
return
if not self.open_orders.has_key(event.sid):
+88 -9
View File
@@ -60,6 +60,8 @@ before invoking simulate.
+---------------------------------+
"""
import inspect
import logbook
import zipline.utils.factory as factory
from zipline.components import DataSource
@@ -71,6 +73,8 @@ from zipline.core.devsimulator import Simulator
from zipline.core.monitor import Controller
from zipline.finance.trading import SIMULATION_STYLE
log = logbook.Logger('Lines')
class SimulatedTrading(object):
"""
Zipline with::
@@ -113,6 +117,8 @@ class SimulatedTrading(object):
self.trading_environment = config['trading_environment']
self.sim_style = config.get('simulation_style')
self.devel = config.get('devel', False)
self.leased_sockets = []
self.sim_context = None
@@ -123,14 +129,15 @@ class SimulatedTrading(object):
'feed_address' : sockets[2],
'merge_address' : sockets[3],
'result_address' : sockets[4],
'order_address' : sockets[5]
'order_address' : sockets[5]
}
self.con = Controller(
sockets[6],
sockets[7],
devel = self.devel
)
# TODO: Not freeform
self.con.manage(
'freeform'
@@ -167,7 +174,7 @@ class SimulatedTrading(object):
def create_test_zipline(**config):
"""
:param config: A configuration object that is a dict with:
- environment - a \
:py:class:`zipline.finance.trading.TradingEnvironment`
- allocator - a :py:class:`zipline.simulator.AddressAllocator`
@@ -190,7 +197,7 @@ class SimulatedTrading(object):
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
"""
assert isinstance(config, dict)
allocator = config['allocator']
sid = config['sid']
@@ -266,12 +273,20 @@ class SimulatedTrading(object):
'allocator' : allocator,
'simulator_class' : simulator_class,
'simulation_style' : simulation_style,
'log_socket' : log_socket
'log_socket' : log_socket,
'devel' : config.get('devel', False)
})
#-------------------
zipline.add_source(trade_source)
# Save us from needless debugging
inside_test = 'nose' in inspect.stack()[-1][1]
if inside_test and not config.get('devel', False):
assert False, """
You need to run the SimulatedTrading inside a test with devel=True
"""
return zipline
def add_source(self, source):
@@ -285,7 +300,7 @@ class SimulatedTrading(object):
self.sim.register_components([source])
# ``id`` is name of source_id, ``get_id`` is the class name
self.sources[source.source_id] = source
self.sources[source.get_id] = source
def add_transform(self, transform):
assert isinstance(transform, BaseTransform)
@@ -324,12 +339,76 @@ class SimulatedTrading(object):
return leased
def simulate(self, blocking=False):
@property
def components(self):
"""
Return the component instances inside of this topology
"""
base = set(self.sim.components.values())
transforms = set(self.transforms.values())
sources = set(self.sources.values())
return base | transforms | sources
@property
def topology(self):
"""
Returns the Component names in the topology of the
backtest.
"""
# A complete topology is the union of three classes of
# components added individually to the simulation client
# at various places.
#
# base : ['FEED', 'MERGE', 'TRADING_CLIENT', 'PASSTHROUGH']
# transforms : ['vwap__01', ... ]
# sources : ['MongoTradeHistory', ... ]
base = set(self.sim.components.keys())
transforms = set(self.transforms.keys())
sources = set(self.sources.keys())
return base | transforms | sources
def setup_controller(self):
"""
Prepare the controller tro manage the topology specified
by this line.
"""
self.con.manage(self.topology)
def simulate(self, blocking=True):
self.setup_controller()
self.started = True
self.sim_context = self.sim.simulate()
if blocking:
self.sim_context.join()
# If we're in development mode then flag all the
# components in the topology as devel so as to indicate
# that they won't poll on the control channels for
# anything other than the synchronized start.
if self.devel:
for component in self.components:
component.devel = True
# If we're using a threaded simulator block on the pool
# of thread since we're only ever in a test and we don't
# generally monitor the state of the system as a hold at
# the supervisory layer
# TODO: better way of identifying concurrency substrate
if self.sim.zmq_flavor == 'thread':
log.debug('Blocking')
for thread in self.sim.subthreads:
#log.debug('Waiting on %r' % thread)
log.debug('Waiting on %r' % thread)
thread.join()
log.debug('Yielded on %r' % thread)
else:
for process in self.sim.subprocesses:
process.join()
@property
def is_success(self):
+2 -1
View File
@@ -65,7 +65,7 @@ def create_updown_trade_source(sid, trade_count, trading_environment, base_price
trading_environment.period_end = cur
source = SpecificEquityTrades("updown_" + str(sid), events)
source = SpecificEquityTrades(events)
return source
@@ -128,6 +128,7 @@ def create_predictable_zipline(config, offset=0, simulate=True):
config['trade_source'] = source
config['environment'] = trading_environment
config['simulation_style'] = SIMULATION_STYLE.FIXED_SLIPPAGE
config['devel'] = True
zipline = SimulatedTrading.create_test_zipline(**config)
+3 -2
View File
@@ -141,11 +141,12 @@ CONTROL_PROTOCOL = Enum(
'HEARTBEAT' , # 0 - req
'SHUTDOWN' , # 1 - req
'KILL' , # 2 - req
'GO' , # - req
'OK' , # 3 - rep
'DONE' , # 4 - rep
'EXCEPTION' , # 5 - rep
'SIGNAL' , # 6 - rep
'READY' , # 6 - rep
)
def CONTROL_FRAME(event, payload):
@@ -305,7 +306,7 @@ def FEED_FRAME(event):
- source_id
- type
"""
assert isinstance(event, ndict)
assert isinstance(event, ndict), 'unknown type %s' % str(event)
source_id = event.source_id
ds_type = event.type
PACK_DATE(event)
-3
View File
@@ -1,12 +1,9 @@
import logging
from zipline.core.component import Component
import zipline.protocol as zp
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
CONTROL_FRAME, CONTROL_UNFRAME
LOGGER = logging.getLogger('ZiplineLogger')
class BaseTransform(Component):
"""
Top level execution entry point for the transform
+2 -2
View File
@@ -171,7 +171,7 @@ def create_returns_from_list(returns, trading_calendar):
def create_random_trade_source(sid, trade_count, trading_environment):
# create the source
source = RandomEquityTrades(sid, "rand-"+str(sid), trade_count)
source = RandomEquityTrades(sid, trade_count)
# make the period_end of trading_environment match
cur = trading_environment.first_open
@@ -244,5 +244,5 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ
#history.
trading_environment.period_end = trade_history[-1].dt
source = SpecificEquityTrades("flat", trade_history)
source = SpecificEquityTrades(trade_history)
return source