Modified the zmq_gen method to yield None when no message is waiting. This prevents the next() method of a component from blocking, but it requires generators that wrap the component to handle None.

Also changed the component's receiver so that it is created on the first call to next() rather than in __iter__().
This change means that the zmq context and socket for the component's receiver should always be created in
the same process as the consumer of the generator. Chaining component-wrapped generators together will
result in the send process of the last component instantiating the receive socket of the prior component.
In this way, the components communicate directly via zmq.

Component's send method now calls wait_ready(), which waits for the monitor's GO message, inside
the generator loop. This guarantees that the generator's next() method is called before the send loop blocks
on the monitor. As a result, a component's __init__ and next() can be called without blocking, mimicking the
behavior of plain generators.
This commit is contained in:
fawce
2012-08-04 12:58:07 -04:00
parent 4a3773848a
commit 4a582e8952
5 changed files with 228 additions and 41 deletions
+178 -8
View File
@@ -1,13 +1,20 @@
import zmq
import pytz
from pprint import pformat as pf
from datetime import datetime, timedelta
from unittest2 import TestCase
from collections import defaultdict
from zipline.gens.composites import date_sorted_sources
from zipline.gens.composites import date_sorted_sources, merged_transforms
from zipline.finance.trading import SIMULATION_STYLE
from zipline.core.devsimulator import AddressAllocator
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
from zipline.utils.factory import create_trading_environment
from zipline.test_algorithms import TestAlgorithm
from zipline.utils.test_utils import (
setup_logger,
@@ -19,7 +26,12 @@ from zipline.utils.test_utils import (
from zipline.core import Component
from zipline.protocol import (
DATASOURCE_FRAME,
DATASOURCE_UNFRAME
DATASOURCE_UNFRAME,
FEED_FRAME,
FEED_UNFRAME,
MERGE_FRAME,
MERGE_UNFRAME,
SIMULATION_STYLE
)
from zipline.gens.tradegens import SpecificEquityTrades
@@ -65,9 +77,7 @@ class ComponentTestCase(TestCase):
}
trade_gen = SpecificEquityTrades(*args_a, **kwargs_a)
monitor.add_to_topology(trade_gen.get_hash())
launch_monitor(monitor)
comp_a = Component(
trade_gen,
@@ -77,9 +87,14 @@ class ComponentTestCase(TestCase):
DATASOURCE_UNFRAME
)
launch_monitor(monitor)
for event in comp_a:
log.info(event)
# wait for the sending process to exit
comp_a.proc.join()
def test_sort(self):
monitor = create_monitor(allocator)
@@ -97,7 +112,6 @@ class ComponentTestCase(TestCase):
'count' : count
}
trade_gen_a = SpecificEquityTrades(*args_a, **kwargs_a)
monitor.add_to_topology(trade_gen_a.get_hash())
#Set up source b. Two minutes between events.
args_b = tuple()
@@ -109,7 +123,6 @@ class ComponentTestCase(TestCase):
'count' : count
}
trade_gen_b = SpecificEquityTrades(*args_b, **kwargs_b)
monitor.add_to_topology(trade_gen_b.get_hash())
#Set up source c. Three minutes between events.
args_c = tuple()
@@ -122,9 +135,7 @@ class ComponentTestCase(TestCase):
}
trade_gen_c = SpecificEquityTrades(*args_c, **kwargs_c)
monitor.add_to_topology(trade_gen_c.get_hash())
launch_monitor(monitor)
comp_a = Component(
trade_gen_a,
@@ -154,6 +165,8 @@ class ComponentTestCase(TestCase):
sorted_out = date_sorted_sources(*sources)
launch_monitor(monitor)
prev = None
sort_count = 0
for msg in sorted_out:
@@ -164,3 +177,160 @@ class ComponentTestCase(TestCase):
sort_count += 1
self.assertEqual(count*3, sort_count)
# wait for processes to finish
comp_a.proc.join()
comp_b.proc.join()
comp_c.proc.join()
def test_full(self):
    """Run the whole pipeline with every stage wrapped in a Component.

    Two trade sources are date-sorted, passed through a passthrough and
    a moving-average transform, merged, and finally consumed by the
    trade simulation client.  The monitor is launched only after all
    components have been constructed.
    """
    monitor = create_monitor(allocator)
    # Shared by both sources; named to avoid shadowing the builtin
    # ``filter``.
    source_filter = [2, 3]

    # Source a: one event every six hours.
    source_a = SpecificEquityTrades(
        count=325,
        sids=[1, 2, 3],
        start=datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
        delta=timedelta(hours=6),
        filter=source_filter
    )

    # Source b: one event every five minutes.
    source_b = SpecificEquityTrades(
        count=7500,
        sids=[2, 3, 4],
        start=datetime(2012, 1, 3, 14, tzinfo=pytz.utc),
        delta=timedelta(minutes=5),
        filter=source_filter
    )

    # ------------------------
    # Run sources in dedicated processes
    comp_a = Component(
        source_a,
        monitor,
        allocator.lease(1)[0],
        DATASOURCE_FRAME,
        DATASOURCE_UNFRAME,
        source_a.get_hash()
    )
    comp_b = Component(
        source_b,
        monitor,
        allocator.lease(1)[0],
        DATASOURCE_FRAME,
        DATASOURCE_UNFRAME,
        source_b.get_hash()
    )

    # Date sort the sources, and run the sort in a dedicated
    # process
    sorted_out = date_sorted_sources(comp_a, comp_b)
    sort_comp = Component(
        sorted_out,
        monitor,
        allocator.lease(1)[0],
        FEED_FRAME,
        FEED_UNFRAME,
        "sort"
    )

    # Apply both transforms to the sorted feed and merge their output
    # in another dedicated process.
    passthrough = StatefulTransform(Passthrough)
    mavg_price = StatefulTransform(
        MovingAverage,
        timedelta(minutes=20),
        ['price']
    )
    merged_gen = merged_transforms(sort_comp, passthrough, mavg_price)
    merged = Component(
        merged_gen,
        monitor,
        allocator.lease(1)[0],
        MERGE_FRAME,
        MERGE_UNFRAME,
        "merge"
    )

    algo = TestAlgorithm(2, 10, 100, sid_filter=[2, 3])
    environment = create_trading_environment(year=2012)
    style = SIMULATION_STYLE.FIXED_SLIPPAGE
    trading_client = tsc(algo, environment, style)

    launch_monitor(monitor)

    for message in trading_client.simulate(merged):
        log.info(pf(message))

    # wait for processes to finish
    for component in (comp_a, comp_b, sort_comp, merged):
        component.proc.join()
def test_compound(self):
    """Date-sort two plain sources and stream the result via one Component.

    Unlike the per-source tests, the sources are not wrapped
    individually: only the compound ``date_sorted_sources`` generator is
    handed to a Component.
    """
    monitor = create_monitor(allocator)
    # Shared by both sources; named to avoid shadowing the builtin
    # ``filter``.
    source_filter = [2, 3]

    # Source a: one event every six hours.
    source_a = SpecificEquityTrades(
        count=325,
        sids=[1, 2, 3],
        start=datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
        delta=timedelta(hours=6),
        filter=source_filter
    )

    # Source b: one event every five minutes.
    source_b = SpecificEquityTrades(
        count=7500,
        sids=[2, 3, 4],
        start=datetime(2012, 1, 3, 14, tzinfo=pytz.utc),
        delta=timedelta(minutes=5),
        filter=source_filter
    )

    sorted_out = date_sorted_sources(source_a, source_b)

    # Component requires an explicit component_id as its sixth
    # argument; omitting it raises a TypeError.
    sort_comp = Component(
        sorted_out,
        monitor,
        allocator.lease(1)[0],
        FEED_FRAME,
        FEED_UNFRAME,
        "sort"
    )

    launch_monitor(monitor)

    for event in sort_comp:
        log.info(event)

    # wait for the sending process to exit
    sort_comp.proc.join()
+33 -22
View File
@@ -51,7 +51,8 @@ class Component(object):
monitor,
socket_uri,
frame,
unframe
unframe,
component_id
):
# -----------------
@@ -59,7 +60,7 @@ class Component(object):
# -----------------
self.generator = generator
self.frame = frame
self.component_id = self.generator.get_hash()
self.component_id = component_id
# lock for waiting on monitor "GO"
self.waiting = None
@@ -99,15 +100,16 @@ class Component(object):
# first, start the generator in its own process. Once
# Monitor says "go", Events from the generator will be
# FRAME'd and PUSH'd to self.socket_uri.
proc = multiprocessing.Process(
target=self.loop_send
)
proc.start()
monitor.add_to_topology(self.component_id)
# ------------
# Message Receiver/Generator
# ------------
self.recv_gen = self.create_recv_gen()
self.proc = multiprocessing.Process(
target=self.loop_send
)
self.proc.start()
# Placeholder for the receive generator, which is created
# lazily on the first call to next()
self.recv_gen = None
# ------------
@@ -123,8 +125,8 @@ class Component(object):
"""
try:
# The process title so you can watch it in top, ps.
setproctitle(self.generator.__class__.__name__)
self.prefix = "FORK-"
setproctitle(self.get_id)
log.info("Start %r" % self)
log.info("Pid %s" % os.getpid())
@@ -134,14 +136,15 @@ class Component(object):
self.signal_ready()
self.lock_ready()
self.wait_ready()
# -----------------------
# YOU SHALL NOT PASS!!!!!
# -----------------------
# ... until the monitor signals GO
msg = None
for event in self.generator:
if hasattr(event, 'dt') and event.dt == 'DONE':
continue
self.wait_ready()
self.heartbeat()
msg = self.frame(event)
self.out_socket.send(msg)
@@ -163,9 +166,6 @@ class Component(object):
def create_recv_gen(self):
try:
self.open(send=False)
self.signal_ready()
self.lock_ready()
# return the generator
return self.loop_recv()
except Exception as exc:
@@ -175,8 +175,12 @@ class Component(object):
def loop_recv(self):
try:
self.open(send=False)
self.signal_ready()
self.lock_ready()
# we block on ready here until monitor sends the GO
self.wait_ready()
# self.wait_ready()
for event in self.gen_from_poller(self.poll, self.in_socket, self.unframe):
yield event
@@ -189,7 +193,10 @@ class Component(object):
def gen_from_poller(self, poller, in_socket, unframe):
while True:
socks = dict(poller.poll(0))
# Since we will yield None to avoid blocking, we need
# to have a small delay to give the poller a chance
# to receive a message from upstream.
socks = dict(poller.poll(100))
self.heartbeat()
if socks.get(in_socket) == zmq.POLLIN:
message = in_socket.recv()
@@ -198,6 +205,8 @@ class Component(object):
else:
event = unframe(message)
yield event
else:
yield
def handle_exception(self, exc, re_raise=False):
if isinstance(exc, KillSignal):
@@ -215,6 +224,8 @@ class Component(object):
return self
def next(self):
    """Return the next item from the receive generator.

    The receive generator is built via ``create_recv_gen()`` the first
    time this method is called and reused on every later call.
    """
    receiver = self.recv_gen
    if not receiver:
        receiver = self.recv_gen = self.create_recv_gen()
    return receiver.next()
# ----------------------------
+5 -5
View File
@@ -17,9 +17,9 @@ def merge(stream_in, tnfm_ids):
and merge them together into an event. We raise an error if we
do not receive the same number of events from all sources.
"""
assert isinstance(tnfm_ids, list)
# Set up an internal queue for each expected source.
tnfms = {}
for id in tnfm_ids:
@@ -36,7 +36,7 @@ def merge(stream_in, tnfm_ids):
id = message.tnfm_id
assert id in tnfm_ids, \
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
tnfms[id].append(message)
# Only pop messages when we have a pending message from
@@ -58,13 +58,13 @@ def merge_one(sources):
event_fields = ndict()
for key, queue in sources.iteritems():
# Add transform value to the transforms dict.
message = queue.popleft()
event_fields[message.tnfm_id] = message.tnfm_value
del message['tnfm_id']
del message['tnfm_value']
# Merge any remaining fields into the event dict.
event_fields.merge(message)
return event_fields
+9 -5
View File
@@ -53,16 +53,16 @@ class StatefulTransform(object):
"Stateful transform requires a class."
assert tnfm_class.__dict__.has_key('update'), \
"Stateful transform requires the class to have an update method"
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
# You can't be both a forwarder and an updater.
assert not all([self.forward_all, self.update_in_place])
# Create an instance of our transform class.
self.state = tnfm_class(*args, **kwargs)
# Create the string associated with this generator's output.
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
@@ -76,7 +76,11 @@ class StatefulTransform(object):
# IMPORTANT: Messages may contain pointers that are shared with
# other streams, so we only manipulate copies.
for message in stream_in:
# allow upstream generators to yield None to avoid
# blocking.
if message == None:
continue
assert_sort_unframe_protocol(message)
message_copy = deepcopy(message)
@@ -90,7 +94,7 @@ class StatefulTransform(object):
out_message.tnfm_id = self.namestring
out_message.tnfm_value = tnfm_value
yield out_message
# Our expectation is that the transform simply updated the
# message it was passed. Useful for chaining together
# multiple transforms, e.g. TransactionSimulator/PerformanceTracker.
+3 -1
View File
@@ -49,7 +49,9 @@ def roundrobin(sources, namestrings):
for namestring, source in mapping.iteritems():
try:
message = source.next()
yield message
# allow sources to yield None to avoid blocking.
if message:
yield message
except StopIteration:
yield done_message(namestring)
del mapping[namestring]