mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 16:44:59 +08:00
Merge branch 'new_world_order' of github.com:quantopian/zipline into new_world_order
Conflicts: zipline/gens/examples.py
This commit is contained in:
+110
-38
@@ -4,17 +4,14 @@ from datetime import datetime, timedelta
|
||||
|
||||
from unittest2 import TestCase
|
||||
from collections import defaultdict
|
||||
from zipline.gens.composite import date_sorted_sources
|
||||
|
||||
from zipline.finance.trading import SIMULATION_STYLE
|
||||
from zipline.core.devsimulator import AddressAllocator
|
||||
from zipline.lines import SimulatedTrading
|
||||
|
||||
from zipline.utils.test_utils import (
|
||||
drain_zipline,
|
||||
check,
|
||||
setup_logger,
|
||||
teardown_logger,
|
||||
launch_component,
|
||||
create_monitor,
|
||||
launch_monitor
|
||||
)
|
||||
@@ -28,7 +25,7 @@ from zipline.protocol import (
|
||||
)
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import hash_args
|
||||
from zipline.gens.sort import date_sort
|
||||
from zipline.gens.zmqgen import gen_from_poller
|
||||
|
||||
import logbook
|
||||
@@ -53,10 +50,14 @@ class ComponentTestCase(TestCase):
|
||||
setup_logger(self)
|
||||
|
||||
def tearDown(self):
|
||||
self.ctx.term()
|
||||
#self.ctx.term()
|
||||
teardown_logger(self)
|
||||
|
||||
def test_specific_equity_source(self):
|
||||
def test_source(self):
|
||||
monitor = create_monitor(allocator)
|
||||
socket_uri = allocator.lease(1)[0]
|
||||
count = 100
|
||||
|
||||
filter = [1,2,3,4]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
@@ -65,42 +66,113 @@ class ComponentTestCase(TestCase):
|
||||
'start' : datetime(2012,6,6,0,tzinfo=pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : 100
|
||||
'count' : count
|
||||
}
|
||||
|
||||
c_id = SpecificEquityTrades.__name__ + hash_args(args_a, kwargs_a)
|
||||
mon = create_monitor(allocator)
|
||||
|
||||
out_socket_args = ComponentSocketArgs(
|
||||
style=zmq.PUSH,
|
||||
uri=allocator.lease(1)[0],
|
||||
bind=True
|
||||
comp_a = Component(
|
||||
SpecificEquityTrades,
|
||||
args_a,
|
||||
kwargs_a,
|
||||
monitor,
|
||||
socket_uri,
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME
|
||||
)
|
||||
|
||||
c = Component(
|
||||
SpecificEquityTrades,
|
||||
args_a,
|
||||
kwargs_a,
|
||||
c_id,
|
||||
out_socket_args,
|
||||
DATASOURCE_FRAME,
|
||||
mon
|
||||
)
|
||||
launch_monitor(monitor)
|
||||
|
||||
mon.manage(set([c.get_id]))
|
||||
mon_proc = launch_monitor(mon)
|
||||
for event in comp_a:
|
||||
log.info(event)
|
||||
|
||||
# launch in a process
|
||||
proc = launch_component(c)
|
||||
|
||||
pull_socket = self.ctx.socket(zmq.PULL)
|
||||
pull_socket.connect(out_socket_args.uri)
|
||||
poller = zmq.Poller()
|
||||
poller.register(pull_socket, zmq.POLLIN)
|
||||
unframe = DATASOURCE_UNFRAME
|
||||
for msg in gen_from_poller(poller, pull_socket, unframe):
|
||||
# assert things about the messages.
|
||||
log.info(msg)
|
||||
def test_sort(self):
|
||||
monitor = create_monitor(allocator)
|
||||
poller = zmq.Poller()
|
||||
socket_uris = allocator.lease(3)
|
||||
count = 100
|
||||
|
||||
pull_socket.close()
|
||||
log.info("DONE!")
|
||||
filter = [1,2,3,4]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'sids' : [1,2],
|
||||
'start' : datetime(2012,6,6,0,tzinfo=pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
|
||||
|
||||
comp_a = Component(
|
||||
SpecificEquityTrades,
|
||||
args_a,
|
||||
kwargs_a,
|
||||
monitor,
|
||||
socket_uris[0],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME
|
||||
)
|
||||
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'sids' : [2],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
|
||||
|
||||
comp_b = Component(
|
||||
SpecificEquityTrades,
|
||||
args_b,
|
||||
kwargs_b,
|
||||
monitor,
|
||||
socket_uris[1],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME
|
||||
)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {
|
||||
'sids' : [3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
|
||||
comp_c = Component(
|
||||
SpecificEquityTrades,
|
||||
args_c,
|
||||
kwargs_c,
|
||||
monitor,
|
||||
socket_uris[2],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME
|
||||
)
|
||||
|
||||
names = [
|
||||
comp_a.get_id,
|
||||
comp_b.get_id,
|
||||
comp_c.get_id
|
||||
]
|
||||
|
||||
monitor.manage(set(names))
|
||||
launch_monitor(monitor)
|
||||
|
||||
sorted_out = date_sorted_sources([comp_a, comp_b, comp_c])
|
||||
|
||||
prev = None
|
||||
sort_count = 0
|
||||
for msg in sorted_out:
|
||||
if prev:
|
||||
self.assertTrue(msg.dt >= prev.dt, \
|
||||
"Messages should be in date ascending order")
|
||||
prev = msg
|
||||
sort_count += 1
|
||||
|
||||
self.assertEqual(count*3, sort_count)
|
||||
|
||||
+80
-42
@@ -10,8 +10,11 @@ import socket
|
||||
import logbook
|
||||
import traceback
|
||||
import humanhash
|
||||
import multiprocessing
|
||||
from setproctitle import setproctitle
|
||||
from collections import namedtuple
|
||||
from zipline.gens.utils import hash_args
|
||||
|
||||
|
||||
# pyzmq
|
||||
import zmq
|
||||
@@ -36,7 +39,7 @@ class KillSignal(Exception):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
ComponentSocketArgs = namedtuple('ComponentSocket',['uri','style','bind'])
|
||||
ComponentSocketArgs = namedtuple('ComponentSocketArgs',['uri','style','bind'])
|
||||
|
||||
class Component(object):
|
||||
|
||||
@@ -49,33 +52,27 @@ class Component(object):
|
||||
gen_args,
|
||||
gen_kwargs,
|
||||
component_id,
|
||||
out_socket_args,
|
||||
frame,
|
||||
monitor,
|
||||
in_socket_args=None,
|
||||
unframe=None
|
||||
socket_uri,
|
||||
frame,
|
||||
unframe
|
||||
):
|
||||
|
||||
assert component_id, \
|
||||
"Every component needs a unique and invariant identifier"
|
||||
assert isinstance(component_id, basestring), \
|
||||
"Components must have string IDs"
|
||||
assert isinstance(out_socket_args, ComponentSocketArgs), \
|
||||
"out_socket_args args must be ComponentSocketArgs"
|
||||
|
||||
if in_socket_args:
|
||||
assert isinstance(in_socket_args, ComponentSocketArgs), \
|
||||
"in_socket_args args must be ComponentSocketArgs"
|
||||
|
||||
# -----------------
|
||||
# Generator
|
||||
# -----------------
|
||||
self.component_id = component_id
|
||||
self.gen_args = gen_args
|
||||
self.gen_kwargs = gen_kwargs
|
||||
self.gen_func = gen_func
|
||||
self.generator = None
|
||||
self.frame = frame
|
||||
self.component_id = self.gen_func.__name__ \
|
||||
+ hash_args(gen_args, gen_kwargs)
|
||||
|
||||
# lock for waiting on monitor "GO"
|
||||
self.waiting = None
|
||||
@@ -83,14 +80,27 @@ class Component(object):
|
||||
# -----------------
|
||||
# ZMQ properties
|
||||
# -----------------
|
||||
self.in_socket_args = in_socket_args
|
||||
self.out_socket_args = out_socket_args
|
||||
self.in_socket_args = ComponentSocketArgs(
|
||||
uri = socket_uri,
|
||||
style = zmq.PULL,
|
||||
bind = False
|
||||
)
|
||||
self.out_socket_args = ComponentSocketArgs(
|
||||
uri = socket_uri,
|
||||
style = zmq.PUSH,
|
||||
bind = True
|
||||
)
|
||||
self.zmq = None
|
||||
self.context = None
|
||||
self.out_socket = None
|
||||
self.in_socket = None
|
||||
self.monitor = monitor
|
||||
self.monitor = monitor
|
||||
self.unframe = unframe
|
||||
self.prefix = ""
|
||||
|
||||
# register two components with the monitor
|
||||
monitor.add_to_topology(self.component_id)
|
||||
monitor.add_to_topology("FORK-"+self.component_id)
|
||||
|
||||
# TODO: state_flag is deprecated, remove
|
||||
self.state_flag = COMPONENT_STATE.OK
|
||||
@@ -109,7 +119,7 @@ class Component(object):
|
||||
# ------------
|
||||
|
||||
|
||||
def _run(self):
|
||||
def _run_out(self):
|
||||
"""
|
||||
The main component loop. This is wrapped inside a
|
||||
exception reporting context inside of run.
|
||||
@@ -118,13 +128,12 @@ class Component(object):
|
||||
"""
|
||||
# The process title so you can watch it in top, ps.
|
||||
setproctitle(self.gen_func.__name__)
|
||||
self.prefix = "FORK-"
|
||||
|
||||
log.info("Start %r" % self)
|
||||
log.info("Pid %s" % os.getpid())
|
||||
log.info("Group %s" % os.getpgrp())
|
||||
|
||||
self.sockets = []
|
||||
|
||||
self.open()
|
||||
|
||||
self.signal_ready()
|
||||
@@ -138,17 +147,36 @@ class Component(object):
|
||||
|
||||
for event in self.generator:
|
||||
self.heartbeat()
|
||||
event.source_id = self.get_id
|
||||
msg = self.frame(event)
|
||||
self.out_socket.send(msg)
|
||||
|
||||
self.signal_done()
|
||||
|
||||
def run(self, catch_exceptions=True):
|
||||
def _run_in(self):
|
||||
self.open(send=False)
|
||||
self.signal_ready()
|
||||
self.lock_ready()
|
||||
self.wait_ready()
|
||||
# -----------------------
|
||||
# YOU SHALL NOT PASS!!!!!
|
||||
# -----------------------
|
||||
# ... until the monitor signals GO
|
||||
|
||||
# return the generator
|
||||
for event in gen_from_poller(self.poll, self.in_socket, self.unframe):
|
||||
event.source_id = self.get_id
|
||||
yield event
|
||||
|
||||
self.signal_done()
|
||||
|
||||
def run_safe(self, func):
|
||||
"""
|
||||
Run the component.
|
||||
Run a function that is assumed to include wait_ready and
|
||||
heartbeat. Used to wrap fork_generator and consume_gen.
|
||||
"""
|
||||
try:
|
||||
self._run()
|
||||
return func()
|
||||
except Exception as exc:
|
||||
if not isinstance(exc, KillSignal):
|
||||
self.signal_exception(exc)
|
||||
@@ -160,6 +188,23 @@ class Component(object):
|
||||
log.info("Exiting %r" % self)
|
||||
|
||||
|
||||
def _launch(self):
|
||||
# first, start the generator in its own process. Once
|
||||
# Monitor says "go", Events from the generator will be
|
||||
# FRAME'd and PUSH'd to self.socket_uri.
|
||||
proc = multiprocessing.Process(
|
||||
target=self.run_safe,
|
||||
args=(self._run_out,)
|
||||
)
|
||||
proc.start()
|
||||
|
||||
# Start the poller-generator, which will PULL messages
|
||||
# from self.sockiet_uri, UNFRAME'd them, and yield them.
|
||||
return self.run_safe(self._run_in)
|
||||
|
||||
def __iter__(self):
|
||||
return self._launch()
|
||||
|
||||
# ----------------------------
|
||||
# Cleanup & Modes of Failure
|
||||
# ----------------------------
|
||||
@@ -420,8 +465,9 @@ class Component(object):
|
||||
# notify internal work loop that we're done
|
||||
self.done = True # TODO: use state flag
|
||||
|
||||
msg = zmq.Message(str(CONTROL_PROTOCOL.DONE))
|
||||
self.out_socket.send(msg)
|
||||
if self.out_socket:
|
||||
msg = zmq.Message(str(CONTROL_PROTOCOL.DONE))
|
||||
self.out_socket.send(msg)
|
||||
|
||||
|
||||
# notify monitor we're done
|
||||
@@ -437,40 +483,32 @@ class Component(object):
|
||||
# after the Monitor accepts our prior heartbeat, but just
|
||||
# before the next one is sent. So, we hang around for one
|
||||
# last heartbeat, and wait an unusually long time.
|
||||
self.heartbeat(timeout=5000)
|
||||
# TODO: decided if this is really necessary.
|
||||
# self.heartbeat(timeout=5000)
|
||||
|
||||
# -----------
|
||||
# Messaging
|
||||
# -----------
|
||||
|
||||
def open(self):
|
||||
def open(self, send=True):
|
||||
"""
|
||||
Open the connections needed to start doing work.
|
||||
Perform any setup that must be done within process.
|
||||
"""
|
||||
|
||||
self.sockets = []
|
||||
self.zmq = zmq
|
||||
self.context = self.zmq.Context()
|
||||
self.poll = self.zmq.Poller()
|
||||
|
||||
self.setup_control()
|
||||
|
||||
if self.in_socket_args:
|
||||
self.in_socket = self.open_socket(self.in_socket_args)
|
||||
poller_gen = gen_from_poller(
|
||||
self.poller,
|
||||
self.in_socket,
|
||||
self.unframe
|
||||
)
|
||||
self.generator = self.gen_func(
|
||||
poller_gen,
|
||||
*self.gen_args,
|
||||
**self.gen_kwargs
|
||||
)
|
||||
else:
|
||||
if send:
|
||||
self.generator = self.gen_func(*self.gen_args, **self.gen_kwargs)
|
||||
|
||||
self.out_socket = self.open_socket(self.out_socket_args)
|
||||
self.out_socket = self.open_socket(self.out_socket_args)
|
||||
self.sockets.extend([self.out_socket])
|
||||
else:
|
||||
self.in_socket = self.open_socket(self.in_socket_args)
|
||||
self.sockets.extend([self.in_socket])
|
||||
|
||||
def open_socket(self, sock_args):
|
||||
if sock_args.bind:
|
||||
@@ -577,7 +615,7 @@ class Component(object):
|
||||
The time invariant name for this component.
|
||||
Must be unique within this zipline.
|
||||
"""
|
||||
return self.component_id
|
||||
return self.prefix + self.component_id
|
||||
|
||||
def debug(self):
|
||||
"""
|
||||
|
||||
@@ -105,6 +105,9 @@ class Monitor(object):
|
||||
|
||||
self.missed_beats = Counter()
|
||||
|
||||
# start with an empty topology
|
||||
self.topology = set([])
|
||||
|
||||
self.send_sighup = send_sighup
|
||||
if self.send_sighup:
|
||||
log.info("Request to send sighup/sigint")
|
||||
@@ -116,6 +119,17 @@ class Monitor(object):
|
||||
self.zmq_poller = self.zmq.Poller
|
||||
return
|
||||
|
||||
def add_to_topology(self, component_id):
|
||||
add = set([component_id])
|
||||
self.topology.update(add)
|
||||
|
||||
def freeze_topology(self):
|
||||
if isinstance(self.topology, frozenset):
|
||||
return
|
||||
# we've been incrementally adding components.
|
||||
# time to freeze.
|
||||
self.manage(self.topology)
|
||||
|
||||
def manage(self, topology):
|
||||
"""
|
||||
Give the controller a set set of components to manage and
|
||||
@@ -147,6 +161,7 @@ class Monitor(object):
|
||||
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
|
||||
|
||||
def run(self):
|
||||
self.freeze_topology()
|
||||
self.running = True
|
||||
self.init_zmq()
|
||||
setproctitle('Monitor')
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import pytz
|
||||
from time import sleep
|
||||
|
||||
from pprint import pprint as pp
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
@@ -26,7 +23,7 @@ if __name__ == "__main__":
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
@@ -36,9 +33,10 @@ if __name__ == "__main__":
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
|
||||
sort_out = date_sorted_sources(source_a, source_b)
|
||||
|
||||
# passthrough = TransformBundle(Passthrough, (), {})
|
||||
@@ -58,6 +56,4 @@ if __name__ == "__main__":
|
||||
# for message in client_out:
|
||||
# pp(message)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+14
-14
@@ -7,7 +7,7 @@ from itertools import chain, cycle, ifilter, izip
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.utils.factory import create_trade
|
||||
from zipline.gens.utils import hash_args, mock_done
|
||||
from zipline.gens.utils import hash_args
|
||||
|
||||
def date_gen(start = datetime(2006, 6, 6, 12),
|
||||
delta = timedelta(minutes = 1),
|
||||
@@ -54,9 +54,9 @@ class SpecificEquityTrades(object):
|
||||
Yields all events in event_list that match the given sid_filter.
|
||||
If no event_list is specified, generates an internal stream of events
|
||||
to filter. Returns all events if filter is None.
|
||||
|
||||
|
||||
Configuration options:
|
||||
|
||||
|
||||
count : integer representing number of trades
|
||||
sids : list of values representing simulated internal sids
|
||||
start : start date
|
||||
@@ -67,22 +67,22 @@ class SpecificEquityTrades(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
# We shouldn't get any positional arguments.
|
||||
assert len(args) == 0
|
||||
|
||||
|
||||
# Unpack config dictionary with default values.
|
||||
self.count = kwargs.get('count', 500)
|
||||
self.sids = kwargs.get('sids', [1, 2])
|
||||
self.start = kwargs.get('start', datetime(2012, 6, 6, 0))
|
||||
self.delta = kwargs.get('delta', timedelta(minutes = 1))
|
||||
|
||||
|
||||
# Default to None for event_list and filter.
|
||||
self.event_list = kwargs.get('event_list')
|
||||
self.filter = kwargs.get('filter')
|
||||
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
self.arg_string = hash_args(*args, **kwargs)
|
||||
|
||||
|
||||
self.generator = self.create_fresh_generator()
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return self.generator
|
||||
|
||||
@@ -94,22 +94,22 @@ class SpecificEquityTrades(object):
|
||||
|
||||
def get_hash(self):
|
||||
return self.__class__.__name__ + "-" + self.arg_string
|
||||
|
||||
|
||||
def create_fresh_generator(self):
|
||||
|
||||
|
||||
if self.event_list:
|
||||
unfiltered = (event for event in self.event_list)
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count=self.count,
|
||||
start=self.start,
|
||||
dates = date_gen(count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
prices = mock_prices(self.count)
|
||||
volumes = mock_volumes(self.count)
|
||||
sids = cycle(self.sids)
|
||||
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
@@ -137,7 +137,7 @@ def RandomEquityTrades(object):
|
||||
def __init__(self):
|
||||
# We shouldn't get any positional args.
|
||||
assert args == ()
|
||||
|
||||
|
||||
self.count = config.get('count', 500)
|
||||
self.sids = config.get('sids', [1,2])
|
||||
self.filter = config.get('filter')
|
||||
|
||||
@@ -13,6 +13,11 @@ def gen_from_pull_socket(socket_uri, context, unframe):
|
||||
|
||||
return gen_from_poller(poller, pull_socket, unframe)
|
||||
|
||||
|
||||
# this generator needs to know about the source_ids coming in via
|
||||
# the poller, and need to yield DONE messages for each
|
||||
# source_id.
|
||||
|
||||
def gen_from_poller(poller, in_socket, unframe):
|
||||
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user