mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 11:52:39 +08:00
modified zmq_gen method to yield None when there is no waiting message. This prevents blocking in the next() method of a component. But it requires generators wrapping the component to handle None.
Also modified component's receiver creation to be triggered on the first call to next, rather than iter. This change means that the zmq context and socket for the component's receiver should always be created in the same process as the consumer of the generator. Chaining together component wrapped generators will result in the send process of the last component actually instantiating the receive socket of the prior component. In this way, the components are actually communicating directly via zmq. Component's send method now calls the wait_ready(), which waits for the monitor's GO message, inside the generator loop. This guarantees that the generator's next method is called before the send loop blocks on the monitor. As a result, components will call __init__ and next() without blocking, mimicking the behavior of plain generators.
This commit is contained in:
+178
-8
@@ -1,13 +1,20 @@
|
||||
import zmq
|
||||
import pytz
|
||||
from pprint import pformat as pf
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from unittest2 import TestCase
|
||||
from collections import defaultdict
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
from zipline.gens.composites import date_sorted_sources, merged_transforms
|
||||
|
||||
from zipline.finance.trading import SIMULATION_STYLE
|
||||
from zipline.core.devsimulator import AddressAllocator
|
||||
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
|
||||
from zipline.utils.test_utils import (
|
||||
setup_logger,
|
||||
@@ -19,7 +26,12 @@ from zipline.utils.test_utils import (
|
||||
from zipline.core import Component
|
||||
from zipline.protocol import (
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME
|
||||
DATASOURCE_UNFRAME,
|
||||
FEED_FRAME,
|
||||
FEED_UNFRAME,
|
||||
MERGE_FRAME,
|
||||
MERGE_UNFRAME,
|
||||
SIMULATION_STYLE
|
||||
)
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
@@ -65,9 +77,7 @@ class ComponentTestCase(TestCase):
|
||||
}
|
||||
|
||||
trade_gen = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
monitor.add_to_topology(trade_gen.get_hash())
|
||||
|
||||
launch_monitor(monitor)
|
||||
|
||||
comp_a = Component(
|
||||
trade_gen,
|
||||
@@ -77,9 +87,14 @@ class ComponentTestCase(TestCase):
|
||||
DATASOURCE_UNFRAME
|
||||
)
|
||||
|
||||
launch_monitor(monitor)
|
||||
|
||||
for event in comp_a:
|
||||
log.info(event)
|
||||
|
||||
# wait for the sending process to exit
|
||||
comp_a.proc.join()
|
||||
|
||||
|
||||
def test_sort(self):
|
||||
monitor = create_monitor(allocator)
|
||||
@@ -97,7 +112,6 @@ class ComponentTestCase(TestCase):
|
||||
'count' : count
|
||||
}
|
||||
trade_gen_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
monitor.add_to_topology(trade_gen_a.get_hash())
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
@@ -109,7 +123,6 @@ class ComponentTestCase(TestCase):
|
||||
'count' : count
|
||||
}
|
||||
trade_gen_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
monitor.add_to_topology(trade_gen_b.get_hash())
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
args_c = tuple()
|
||||
@@ -122,9 +135,7 @@ class ComponentTestCase(TestCase):
|
||||
}
|
||||
|
||||
trade_gen_c = SpecificEquityTrades(*args_c, **kwargs_c)
|
||||
monitor.add_to_topology(trade_gen_c.get_hash())
|
||||
|
||||
launch_monitor(monitor)
|
||||
|
||||
comp_a = Component(
|
||||
trade_gen_a,
|
||||
@@ -154,6 +165,8 @@ class ComponentTestCase(TestCase):
|
||||
|
||||
sorted_out = date_sorted_sources(*sources)
|
||||
|
||||
launch_monitor(monitor)
|
||||
|
||||
prev = None
|
||||
sort_count = 0
|
||||
for msg in sorted_out:
|
||||
@@ -164,3 +177,160 @@ class ComponentTestCase(TestCase):
|
||||
sort_count += 1
|
||||
|
||||
self.assertEqual(count*3, sort_count)
|
||||
|
||||
# wait for processes to finish
|
||||
comp_a.proc.join()
|
||||
comp_b.proc.join()
|
||||
comp_c.proc.join()
|
||||
|
||||
|
||||
def test_full(self):
|
||||
monitor = create_monitor(allocator)
|
||||
|
||||
filter = [2,3]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 325,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 7500,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
# ------------------------
|
||||
# Run sources in dedicated processes
|
||||
comp_a = Component(
|
||||
source_a,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
source_a.get_hash()
|
||||
)
|
||||
|
||||
comp_b = Component(
|
||||
source_b,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
source_b.get_hash()
|
||||
)
|
||||
|
||||
# Date sort the sources, and run the sort in a dedicated
|
||||
# process
|
||||
sources = [comp_a, comp_b]
|
||||
|
||||
sorted_out = date_sorted_sources(*sources)
|
||||
|
||||
#launch_monitor(monitor)
|
||||
#import nose.tools; nose.tools.set_trace()
|
||||
#for feed_msg in sorted_out:
|
||||
# log.info(pf(feed_msg))
|
||||
|
||||
#return
|
||||
|
||||
sorted = Component(
|
||||
sorted_out,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
FEED_FRAME,
|
||||
FEED_UNFRAME,
|
||||
"sort"
|
||||
)
|
||||
|
||||
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(
|
||||
MovingAverage,
|
||||
timedelta(minutes = 20),
|
||||
['price']
|
||||
)
|
||||
|
||||
merged_gen = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
merged = Component(
|
||||
merged_gen,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
MERGE_FRAME,
|
||||
MERGE_UNFRAME,
|
||||
"merge"
|
||||
)
|
||||
|
||||
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
environment = create_trading_environment(year = 2012)
|
||||
style = SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
trading_client = tsc(algo, environment, style)
|
||||
|
||||
launch_monitor(monitor)
|
||||
for message in trading_client.simulate(merged):
|
||||
log.info(pf(message))
|
||||
|
||||
|
||||
# wait for processes to finish
|
||||
comp_a.proc.join()
|
||||
comp_b.proc.join()
|
||||
sorted.proc.join()
|
||||
merged.proc.join()
|
||||
return
|
||||
|
||||
|
||||
|
||||
def test_compound(self):
|
||||
monitor = create_monitor(allocator)
|
||||
|
||||
filter = [2,3]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 325,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 7500,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
sorted_out = date_sorted_sources(source_a, source_b)
|
||||
|
||||
sorted = Component(
|
||||
sorted_out,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
FEED_FRAME,
|
||||
FEED_UNFRAME
|
||||
)
|
||||
|
||||
launch_monitor(monitor)
|
||||
|
||||
for event in sorted:
|
||||
log.info(event)
|
||||
|
||||
|
||||
sorted.proc.join()
|
||||
|
||||
+33
-22
@@ -51,7 +51,8 @@ class Component(object):
|
||||
monitor,
|
||||
socket_uri,
|
||||
frame,
|
||||
unframe
|
||||
unframe,
|
||||
component_id
|
||||
):
|
||||
|
||||
# -----------------
|
||||
@@ -59,7 +60,7 @@ class Component(object):
|
||||
# -----------------
|
||||
self.generator = generator
|
||||
self.frame = frame
|
||||
self.component_id = self.generator.get_hash()
|
||||
self.component_id = component_id
|
||||
|
||||
# lock for waiting on monitor "GO"
|
||||
self.waiting = None
|
||||
@@ -99,15 +100,16 @@ class Component(object):
|
||||
# first, start the generator in its own process. Once
|
||||
# Monitor says "go", Events from the generator will be
|
||||
# FRAME'd and PUSH'd to self.socket_uri.
|
||||
proc = multiprocessing.Process(
|
||||
target=self.loop_send
|
||||
)
|
||||
proc.start()
|
||||
monitor.add_to_topology(self.component_id)
|
||||
|
||||
# ------------
|
||||
# Message Receiver/Generator
|
||||
# ------------
|
||||
self.recv_gen = self.create_recv_gen()
|
||||
self.proc = multiprocessing.Process(
|
||||
target=self.loop_send
|
||||
)
|
||||
self.proc.start()
|
||||
|
||||
# Placeholder for receive generator, which will be
|
||||
# created in __iter__
|
||||
self.recv_gen = None
|
||||
|
||||
|
||||
# ------------
|
||||
@@ -123,8 +125,8 @@ class Component(object):
|
||||
"""
|
||||
try:
|
||||
# The process title so you can watch it in top, ps.
|
||||
setproctitle(self.generator.__class__.__name__)
|
||||
self.prefix = "FORK-"
|
||||
setproctitle(self.get_id)
|
||||
|
||||
log.info("Start %r" % self)
|
||||
log.info("Pid %s" % os.getpid())
|
||||
@@ -134,14 +136,15 @@ class Component(object):
|
||||
|
||||
self.signal_ready()
|
||||
self.lock_ready()
|
||||
self.wait_ready()
|
||||
|
||||
# -----------------------
|
||||
# YOU SHALL NOT PASS!!!!!
|
||||
# -----------------------
|
||||
# ... until the monitor signals GO
|
||||
|
||||
msg = None
|
||||
for event in self.generator:
|
||||
|
||||
if hasattr(event, 'dt') and event.dt == 'DONE':
|
||||
continue
|
||||
|
||||
self.wait_ready()
|
||||
|
||||
self.heartbeat()
|
||||
msg = self.frame(event)
|
||||
self.out_socket.send(msg)
|
||||
@@ -163,9 +166,6 @@ class Component(object):
|
||||
|
||||
def create_recv_gen(self):
|
||||
try:
|
||||
self.open(send=False)
|
||||
self.signal_ready()
|
||||
self.lock_ready()
|
||||
# return the generator
|
||||
return self.loop_recv()
|
||||
except Exception as exc:
|
||||
@@ -175,8 +175,12 @@ class Component(object):
|
||||
|
||||
def loop_recv(self):
|
||||
try:
|
||||
self.open(send=False)
|
||||
self.signal_ready()
|
||||
self.lock_ready()
|
||||
|
||||
# we block on ready here until monitor sends the GO
|
||||
self.wait_ready()
|
||||
# self.wait_ready()
|
||||
for event in self.gen_from_poller(self.poll, self.in_socket, self.unframe):
|
||||
yield event
|
||||
|
||||
@@ -189,7 +193,10 @@ class Component(object):
|
||||
def gen_from_poller(self, poller, in_socket, unframe):
|
||||
|
||||
while True:
|
||||
socks = dict(poller.poll(0))
|
||||
# Since we will yield None to avoid blocking, we need
|
||||
# to have a small delay to give the poller a chance
|
||||
# to receive a message from upstream.
|
||||
socks = dict(poller.poll(100))
|
||||
self.heartbeat()
|
||||
if socks.get(in_socket) == zmq.POLLIN:
|
||||
message = in_socket.recv()
|
||||
@@ -198,6 +205,8 @@ class Component(object):
|
||||
else:
|
||||
event = unframe(message)
|
||||
yield event
|
||||
else:
|
||||
yield
|
||||
|
||||
def handle_exception(self, exc, re_raise=False):
|
||||
if isinstance(exc, KillSignal):
|
||||
@@ -215,6 +224,8 @@ class Component(object):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if not self.recv_gen:
|
||||
self.recv_gen = self.create_recv_gen()
|
||||
return self.recv_gen.next()
|
||||
|
||||
# ----------------------------
|
||||
|
||||
@@ -17,9 +17,9 @@ def merge(stream_in, tnfm_ids):
|
||||
and merge them together into an event. We raise an error if we
|
||||
do not receive the same number of events from all sources.
|
||||
"""
|
||||
|
||||
|
||||
assert isinstance(tnfm_ids, list)
|
||||
|
||||
|
||||
# Set up an internal queue for each expected source.
|
||||
tnfms = {}
|
||||
for id in tnfm_ids:
|
||||
@@ -36,7 +36,7 @@ def merge(stream_in, tnfm_ids):
|
||||
id = message.tnfm_id
|
||||
assert id in tnfm_ids, \
|
||||
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
|
||||
|
||||
|
||||
tnfms[id].append(message)
|
||||
|
||||
# Only pop messages when we have a pending message from
|
||||
@@ -58,13 +58,13 @@ def merge_one(sources):
|
||||
event_fields = ndict()
|
||||
|
||||
for key, queue in sources.iteritems():
|
||||
|
||||
|
||||
# Add transform value to the transforms dict.
|
||||
message = queue.popleft()
|
||||
event_fields[message.tnfm_id] = message.tnfm_value
|
||||
del message['tnfm_id']
|
||||
del message['tnfm_value']
|
||||
|
||||
|
||||
# Merge any remaining fields into the event dict.
|
||||
event_fields.merge(message)
|
||||
return event_fields
|
||||
|
||||
@@ -53,16 +53,16 @@ class StatefulTransform(object):
|
||||
"Stateful transform requires a class."
|
||||
assert tnfm_class.__dict__.has_key('update'), \
|
||||
"Stateful transform requires the class to have an update method"
|
||||
|
||||
|
||||
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
|
||||
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
|
||||
|
||||
# You can't be both a forwarded and an updater.
|
||||
assert not all([self.forward_all, self.update_in_place])
|
||||
|
||||
|
||||
# Create an instance of our transform class.
|
||||
self.state = tnfm_class(*args, **kwargs)
|
||||
|
||||
|
||||
# Create the string associated with this generator's output.
|
||||
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
@@ -76,7 +76,11 @@ class StatefulTransform(object):
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
for message in stream_in:
|
||||
|
||||
# allow upstream generators to yield None to avoid
|
||||
# blocking.
|
||||
if message == None:
|
||||
continue
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
message_copy = deepcopy(message)
|
||||
|
||||
@@ -90,7 +94,7 @@ class StatefulTransform(object):
|
||||
out_message.tnfm_id = self.namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
yield out_message
|
||||
|
||||
|
||||
# Our expectation is that the transform simply updated the
|
||||
# message it was passed. Useful for chaining together
|
||||
# multiple transforms, e.g. TransactionSimulator/PerformanceTracker.
|
||||
|
||||
@@ -49,7 +49,9 @@ def roundrobin(sources, namestrings):
|
||||
for namestring, source in mapping.iteritems():
|
||||
try:
|
||||
message = source.next()
|
||||
yield message
|
||||
# allow sources to yield None to avoid blocking.
|
||||
if message:
|
||||
yield message
|
||||
except StopIteration:
|
||||
yield done_message(namestring)
|
||||
del mapping[namestring]
|
||||
|
||||
Reference in New Issue
Block a user