Cleanup docs, pep8ify, and backport dev simulator test

This commit is contained in:
Stephen Diehl
2012-02-23 15:32:17 -05:00
parent 6190da376a
commit e640aabf73
14 changed files with 251 additions and 507 deletions
+1 -1
View File
@@ -40,7 +40,7 @@ source_suffix = '.rst'
master_doc = 'index'
# General information about the project.
project = u'QSim'
project = u'Zipline'
copyright = u'2012, Quantopian: jean, fawce, sdiehl'
# The version info for the project you're documenting, acts as replacement for
+7 -7
View File
@@ -1,4 +1,4 @@
.. QSim documentation master file, created by
.. Zipline documentation master file, created by
sphinx-quickstart on Wed Feb 8 15:29:56 2012.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
@@ -12,15 +12,15 @@ Contents:
modules.rst
messaging.rst
Quantopian Simulator: QSim
================================
Zipline
=======
Qsim runs backtests using asynchronous components and zeromq messaging for communication and coordination.
Zipline runs backtests using asynchronous components and zeromq messaging for communication and coordination.
Simulator is the heart of QSim, and the primary access point for creating, launching, and tracking simulations. You can find it in :py:class:`~zipline.core.Simulator`
Simulator is the heart of Zipline, and the primary access point for creating, launching, and tracking simulations. You can find it in :py:class:`~zipline.core.Simulator`
Simulator Sub-Components
==========================
========================
Each simulation contains numerous subcomponents, each operating asynchronously from all others, and communicating
via zeromq.
@@ -61,4 +61,4 @@ Indices and tables
* :ref:`modindex`
* :ref:`search`
.. _coverage: cover/index.html
.. _coverage: cover/index.html
-63
View File
@@ -1,63 +0,0 @@
============
zipline Package
============
QSim API
===========================
:mod:`zipline` Package
-------------------
.. automodule:: zipline.__init__
:members:
:undoc-members:
:show-inheritance:
:mod:`config` Module
--------------------
.. automodule:: zipline.config
:members:
:undoc-members:
:show-inheritance:
:mod:`core` Module
------------------
.. automodule:: zipline.core
:members:
:undoc-members:
:show-inheritance:
:mod:`messaging` Module
-----------------------
.. automodule:: zipline.messaging
:members:
:undoc-members:
:show-inheritance:
:mod:`sources` Module
---------------------
.. automodule:: zipline.sources
:members:
:undoc-members:
:show-inheritance:
:mod:`util` Module
------------------
.. automodule:: zipline.util
:members:
:undoc-members:
:show-inheritance:
Subpackages
-----------
.. toctree::
zipline.test
zipline.transforms
-19
View File
@@ -1,19 +0,0 @@
test Package
============
:mod:`client` Module
--------------------
.. automodule:: zipline.test.client
:members:
:undoc-members:
:show-inheritance:
:mod:`test_messaging` Module
----------------------------
.. automodule:: zipline.test.test_messaging
:members:
:undoc-members:
:show-inheritance:
-27
View File
@@ -1,27 +0,0 @@
transforms Package
==================
:mod:`core` Module
------------------
.. automodule:: zipline.transforms.core
:members:
:undoc-members:
:show-inheritance:
:mod:`merge` Module
-------------------
.. automodule:: zipline.transforms.merge
:members:
:undoc-members:
:show-inheritance:
:mod:`technical` Module
-----------------------
.. automodule:: zipline.transforms.technical
:members:
:undoc-members:
:show-inheritance:
+48
View File
@@ -9,6 +9,22 @@ zipline Package
:undoc-members:
:show-inheritance:
:mod:`cli` Module
-----------------
.. automodule:: zipline.cli
:members:
:undoc-members:
:show-inheritance:
:mod:`component` Module
-----------------------
.. automodule:: zipline.component
:members:
:undoc-members:
:show-inheritance:
:mod:`messaging` Module
-----------------------
@@ -17,6 +33,22 @@ zipline Package
:undoc-members:
:show-inheritance:
:mod:`monitor` Module
---------------------
.. automodule:: zipline.monitor
:members:
:undoc-members:
:show-inheritance:
:mod:`protocol` Module
----------------------
.. automodule:: zipline.protocol
:members:
:undoc-members:
:show-inheritance:
:mod:`sources` Module
---------------------
@@ -25,6 +57,14 @@ zipline Package
:undoc-members:
:show-inheritance:
:mod:`topology` Module
----------------------
.. automodule:: zipline.topology
:members:
:undoc-members:
:show-inheritance:
:mod:`util` Module
------------------
@@ -33,6 +73,14 @@ zipline Package
:undoc-members:
:show-inheritance:
:mod:`webui` Module
-------------------
.. automodule:: zipline.webui
:members:
:undoc-members:
:show-inheritance:
Subpackages
-----------
+8
View File
@@ -17,3 +17,11 @@ test Package
:undoc-members:
:show-inheritance:
:mod:`test_sanity` Module
-------------------------
.. automodule:: zipline.test.test_sanity
:members:
:undoc-members:
:show-inheritance:
+1 -1
View File
@@ -185,5 +185,5 @@ def apidocs():
Recursively autogenerate the Sphinx autodoc for the module and
its submodules.
"""
call('rm docs/zipline.*.rst', shell=True)
call('rm docs/zipline*.rst', shell=True)
call('sphinx-apidoc -o docs/ zipline', shell=True)
+2 -327
View File
@@ -200,336 +200,11 @@ class Component(object):
def setup_sync(self):
qutil.LOGGER.debug("Connecting sync client for {id}".format(id=self.get_id()))
self.sync_socket = self.context.socket(self.zmq.REQ)
self.sync_socket.connect(self.addresses['sync_address'])
#self.sync_socket.setsockopt(self.zmq.LINGER,0)
self.sync_poller = self.zmq.Poller()
self.sync_poller.register(self.sync_socket, self.zmq.POLLIN)
self.sockets.append(self.sync_socket)
class ComponentHost(Component):
"""
Component that can launch multiple sub-components, synchronize their start, and then wait for all
components to be finished.
"""
def __init__(self, addresses, gevent_needed=False):
Component.__init__(self)
self.addresses = addresses
#workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108
qutil.parse_date("2012/02/13-10:04:28.114")
self.components = {}
self.sync_register = {}
self.timeout = datetime.timedelta(seconds=5)
self.feed = ParallelBuffer()
self.merge = MergedParallelBuffer()
self.passthrough = PassthroughTransform()
self.gevent_needed = gevent_needed
self.controller = None
#register the feed and the merge
self.register_components([self.feed, self.merge, self.passthrough])
def register_controller(self, controller):
self.controller = controller
for component in self.components.itervalues():
component.controller = controller
def register_components(self, components):
for component in components:
component.gevent_needed = self.gevent_needed
component.addresses = self.addresses
if self.controller:
component.controller = self.controller
self.components[component.get_id()] = component
self.sync_register[component.get_id()] = datetime.datetime.utcnow()
if(isinstance(component, DataSource)):
self.feed.add_source(component.get_id())
if(isinstance(component, BaseTransform)):
self.merge.add_source(component.get_id())
def unregister_component(self, component_id):
del(self.components[component_id])
del(self.sync_register[component_id])
def setup_sync(self):
"""Start the sync server."""
qutil.LOGGER.debug("Connecting sync server.")
self.sync_socket = self.context.socket(self.zmq.REP)
self.sync_socket.bind(self.addresses['sync_address'])
self.poller = self.zmq.Poller()
self.poller.register(self.sync_socket, self.zmq.POLLIN)
self.sockets.append(self.sync_socket)
def open(self):
for component in self.components.values():
self.launch_component(component)
self.launch_controller()
def is_timed_out(self):
cur_time = datetime.datetime.utcnow()
if(len(self.components) == 0):
qutil.LOGGER.info("Component register is empty.")
return True
for source, last_dt in self.sync_register.iteritems():
if((cur_time - last_dt) > self.timeout):
qutil.LOGGER.info("Time out for {source}. Current component registery: {reg}".format(source=source, reg=self.components))
return True
return False
def loop(self):
while not self.is_timed_out():
# wait for synchronization request
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.sync_socket in socks and socks[self.sync_socket] == self.zmq.POLLIN:
msg = self.sync_socket.recv()
parts = msg.split(':')
if(len(parts) < 2):
qutil.LOGGER.info("got bad confirm: {msg}".format(msg=msg))
sync_id = parts[0]
status = parts[1]
if(self.is_done_message(status)):
qutil.LOGGER.info("{id} is DONE".format(id=sync_id))
self.unregister_component(sync_id)
else:
self.sync_register[sync_id] = datetime.datetime.utcnow()
#qutil.LOGGER.info("confirmed {id}".format(id=msg))
# send synchronization reply
self.sync_socket.send('ack', self.zmq.NOBLOCK)
def launch_controller(self, controller):
NotImplemented
def launch_component(self, component):
NotImplemented
class ParallelBuffer(Component):
"""Connects to N PULL sockets, publishing all messages received to a PUB socket.
Published messages are guaranteed to be in chronological order based on message property dt.
Expects to be instantiated in one execution context (thread, process, etc) and run in another."""
def __init__(self):
Component.__init__(self)
self.sent_count = 0
self.received_count = 0
self.draining = False
self.data_buffer = {}
self.ds_finished_counter = 0
def get_id(self):
return "FEED"
def add_source(self, source_id):
self.data_buffer[source_id] = []
def open(self):
self.pull_socket, self.poller = self.bind_data()
self.feed_socket = self.bind_feed()
def do_work(self):
# wait for synchronization reply from the host
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.pull_socket in socks and socks[self.pull_socket] == self.zmq.POLLIN:
message = self.pull_socket.recv()
if(self.is_done_message(message)):
self.ds_finished_counter += 1
if(len(self.data_buffer) == self.ds_finished_counter):
#drain any remaining messages in the buffer
self.drain()
self.signal_done()
else:
event = json.loads(message)
self.append(event[u'id'], event)
self.send_next()
def __len__(self):
"""buffer's length is same as internal map holding separate sorted arrays of events keyed by source id"""
return len(self.data_buffer)
def append(self, source_id, value):
"""add an event to the buffer for the source specified by source_id"""
self.data_buffer[source_id].append(value)
self.received_count += 1
def next(self):
"""Get the next message in chronological order"""
if(not(self.is_full() or self.draining)):
return
cur = None
earliest = None
for events in self.data_buffer.values():
if len(events) == 0:
continue
cur = events
if(earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']):
earliest = cur
if(earliest != None):
return earliest.pop(0)
def is_full(self):
"""indicates whether the buffer has messages in buffer for all un-DONE sources"""
for events in self.data_buffer.values():
if (len(events) == 0):
return False
return True
def pending_messages(self):
"""returns the count of all events from all sources in the buffer"""
total = 0
for events in self.data_buffer.values():
total += len(events)
return total
def drain(self):
"""send all messages in the buffer"""
self.draining = True
while(self.pending_messages() > 0):
self.send_next()
def send_next(self):
"""send the (chronologically) next message in the buffer."""
if(not(self.is_full() or self.draining)):
return
event = self.next()
if(event != None):
self.feed_socket.send(json.dumps(event), self.zmq.NOBLOCK)
self.sent_count += 1
class MergedParallelBuffer(ParallelBuffer):
"""
Merges multiple streams of events into single messages.
"""
def __init__(self):
ParallelBuffer.__init__(self)
def open(self):
self.pull_socket, self.poller = self.bind_merge()
self.feed_socket = self.bind_result()
def next(self):
"""Get the next merged message from the feed buffer."""
if(not(self.is_full() or self.draining)):
return
#get the raw event from the passthrough transform.
result = self.data_buffer["PASSTHROUGH"].pop(0)['value']
for source, events in self.data_buffer.iteritems():
if(source == "PASSTHROUGH"):
continue
if(len(events) > 0):
cur = events.pop(0)
result[source] = cur['value']
return result
def get_id(self):
return "MERGE"
class BaseTransform(Component):
"""Top level execution entry point for the transform::
- connects to the feed socket to subscribe to events
- connets to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms
- processes all messages received from feed, until DONE message received
- pushes all transforms
- sends DONE to result socket, closes all sockets and context
Parent class for feed transforms. Subclass and override transform
method to create a new derived value from the combined feed."""
def __init__(self, name):
Component.__init__(self)
self.state = {}
self.state['name'] = name
def get_id(self):
return self.state['name']
def open(self):
"""
Establishes zmq connections.
"""
#create the feed.
self.feed_socket, self.poller = self.connect_feed()
#create the result PUSH
self.result_socket = self.connect_merge()
def do_work(self):
"""
Loops until feed's DONE message is received:
- receive an event from the data feed
- call transform (subclass' method) on event
- send the transformed event
"""
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == self.zmq.POLLIN:
message = self.feed_socket.recv()
if(self.is_done_message(message)):
self.signal_done()
return
event = json.loads(message)
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['id'] = self.state['name']
self.result_socket.send(json.dumps(cur_state), self.zmq.NOBLOCK)
def transform(self, event):
""" Must return the transformed value as a map with {name:"name of new transform", value: "value of new field"}
Transforms run in parallel and results are merged into a single map, so transform names must be unique.
Best practice is to use the self.state object initialized from the transform configuration, and only set the
transformed value:
self.state['value'] = transformed_value
"""
NotImplemented
class PassthroughTransform(BaseTransform):
def __init__(self):
BaseTransform.__init__(self, "PASSTHROUGH")
def transform(self, event):
return {'value':event}
class DataSource(Component):
"""
Baseclass for data sources. Subclass and implement send_all - usually this
means looping through all records in a store, converting to a dict, and
calling send(map).
"""
def __init__(self, source_id):
Component.__init__(self)
self.id = source_id
self.cur_event = None
def get_id(self):
return self.id
def open(self):
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
self.data_socket = self.connect_data()
def send(self, event):
"""
event is expected to be a dict
sets id and type properties in the dict
sends to the data_socket.
"""
event['id'] = self.id
event['type'] = self.get_type()
self.data_socket.send(json.dumps(event))
def get_type(self):
raise NotImplemented
+81 -45
View File
@@ -9,7 +9,7 @@ from zipline.component import Component
class ComponentHost(Component):
"""
Component that can launch multiple sub-components, synchronize their start, and then wait for all
Components that can launch multiple sub-components, synchronize their start, and then wait for all
components to be finished.
"""
@@ -19,13 +19,15 @@ class ComponentHost(Component):
#workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108
qutil.parse_date("2012/02/13-10:04:28.114")
self.components = {}
self.sync_register = {}
self.timeout = datetime.timedelta(seconds=5)
self.gevent_needed = gevent_needed
self.feed = ParallelBuffer()
self.merge = MergedParallelBuffer()
self.passthrough = PassthroughTransform()
self.gevent_needed = gevent_needed
self.controller = None
#register the feed and the merge
@@ -54,14 +56,18 @@ class ComponentHost(Component):
self.merge.add_source(component.get_id())
def unregister_component(self, component_id):
del(self.components[component_id])
del(self.sync_register[component_id])
del self.components[component_id]
del self.sync_register[component_id]
def setup_sync(self):
"""Start the sync server."""
"""
Start the sync server.
"""
qutil.LOGGER.debug("Connecting sync server.")
self.sync_socket = self.context.socket(self.zmq.REP)
self.sync_socket.bind(self.addresses['sync_address'])
self.poller = self.zmq.Poller()
self.poller.register(self.sync_socket, self.zmq.POLLIN)
self.sockets.append(self.sync_socket)
@@ -73,7 +79,8 @@ class ComponentHost(Component):
def is_timed_out(self):
cur_time = datetime.datetime.utcnow()
if(len(self.components) == 0):
if len(self.components) == 0:
qutil.LOGGER.info("Component register is empty.")
return True
for source, last_dt in self.sync_register.iteritems():
@@ -104,15 +111,18 @@ class ComponentHost(Component):
self.sync_socket.send('ack', self.zmq.NOBLOCK)
def launch_controller(self, controller):
NotImplemented
raise NotImplementedError
def launch_component(self, component):
NotImplemented
raise NotImplementedError
class ParallelBuffer(Component):
"""Connects to N PULL sockets, publishing all messages received to a PUB socket.
Published messages are guaranteed to be in chronological order based on message property dt.
Expects to be instantiated in one execution context (thread, process, etc) and run in another."""
"""
Connects to N PULL sockets, publishing all messages received to a PUB socket.
Published messages are guaranteed to be in chronological order based on message property dt.
Expects to be instantiated in one execution context (thread, process, etc) and run in another.
"""
def __init__(self):
Component.__init__(self)
@@ -130,8 +140,8 @@ class ParallelBuffer(Component):
self.data_buffer[source_id] = []
def open(self):
self.pull_socket, self.poller = self.bind_data()
self.feed_socket = self.bind_feed()
self.pull_socket, self.poller = self.bind_data()
self.feed_socket = self.bind_feed()
def do_work(self):
# wait for synchronization reply from the host
@@ -139,9 +149,9 @@ class ParallelBuffer(Component):
if self.pull_socket in socks and socks[self.pull_socket] == self.zmq.POLLIN:
message = self.pull_socket.recv()
if(self.is_done_message(message)):
if self.is_done_message(message):
self.ds_finished_counter += 1
if(len(self.data_buffer) == self.ds_finished_counter):
if len(self.data_buffer) == self.ds_finished_counter:
#drain any remaining messages in the buffer
self.drain()
self.signal_done()
@@ -151,17 +161,25 @@ class ParallelBuffer(Component):
self.send_next()
def __len__(self):
"""buffer's length is same as internal map holding separate sorted arrays of events keyed by source id"""
"""
Buffer's length is same as internal map holding separate
sorted arrays of events keyed by source id.
"""
return len(self.data_buffer)
def append(self, source_id, value):
"""add an event to the buffer for the source specified by source_id"""
"""
Add an event to the buffer for the source specified by
source_id.
"""
self.data_buffer[source_id].append(value)
self.received_count += 1
def next(self):
"""Get the next message in chronological order"""
if(not(self.is_full() or self.draining)):
"""
Get the next message in chronological order.
"""
if not(self.is_full() or self.draining):
return
cur = None
@@ -170,34 +188,44 @@ class ParallelBuffer(Component):
if len(events) == 0:
continue
cur = events
if(earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']):
if (earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']):
earliest = cur
if(earliest != None):
if earliest != None:
return earliest.pop(0)
def is_full(self):
"""indicates whether the buffer has messages in buffer for all un-DONE sources"""
"""
Indicates whether the buffer has messages in buffer for
all un-DONE sources.
"""
for events in self.data_buffer.values():
if (len(events) == 0):
if len(events) == 0:
return False
return True
def pending_messages(self):
"""returns the count of all events from all sources in the buffer"""
"""
Returns the count of all events from all sources in the
buffer.
"""
total = 0
for events in self.data_buffer.values():
total += len(events)
return total
def drain(self):
"""send all messages in the buffer"""
"""
Send all messages in the buffer
"""
self.draining = True
while(self.pending_messages() > 0):
self.send_next()
def send_next(self):
"""send the (chronologically) next message in the buffer."""
"""
Send the (chronologically) next message in the buffer.
"""
if(not(self.is_full() or self.draining)):
return
@@ -216,8 +244,8 @@ class MergedParallelBuffer(ParallelBuffer):
ParallelBuffer.__init__(self)
def open(self):
self.pull_socket, self.poller = self.bind_merge()
self.feed_socket = self.bind_result()
self.pull_socket, self.poller = self.bind_merge()
self.feed_socket = self.bind_result()
def next(self):
"""Get the next merged message from the feed buffer."""
@@ -227,9 +255,9 @@ class MergedParallelBuffer(ParallelBuffer):
#get the raw event from the passthrough transform.
result = self.data_buffer["PASSTHROUGH"].pop(0)['value']
for source, events in self.data_buffer.iteritems():
if(source == "PASSTHROUGH"):
if source == "PASSTHROUGH":
continue
if(len(events) > 0):
if len(events) > 0:
cur = events.pop(0)
result[source] = cur['value']
return result
@@ -252,8 +280,8 @@ class BaseTransform(Component):
def __init__(self, name):
Component.__init__(self)
self.state = {}
self.state['name'] = name
self.state = {}
self.state['name'] = name
def get_id(self):
return self.state['name']
@@ -277,9 +305,10 @@ class BaseTransform(Component):
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == self.zmq.POLLIN:
message = self.feed_socket.recv()
if(self.is_done_message(message)):
if self.is_done_message(message):
self.signal_done()
return
event = json.loads(message)
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
@@ -287,13 +316,19 @@ class BaseTransform(Component):
self.result_socket.send(json.dumps(cur_state), self.zmq.NOBLOCK)
def transform(self, event):
""" Must return the transformed value as a map with {name:"name of new transform", value: "value of new field"}
Transforms run in parallel and results are merged into a single map, so transform names must be unique.
Best practice is to use the self.state object initialized from the transform configuration, and only set the
transformed value:
self.state['value'] = transformed_value
"""
NotImplemented
Must return the transformed value as a map with::
{name:"name of new transform", value: "value of new field"}
Transforms run in parallel and results are merged into a single map, so transform names must be unique.
Best practice is to use the self.state object initialized from the transform configuration, and only set the
transformed value::
self.state['value'] = transformed_value
"""
raise NotImplementedError
class PassthroughTransform(BaseTransform):
@@ -303,6 +338,7 @@ class PassthroughTransform(BaseTransform):
def transform(self, event):
return {'value':event}
class DataSource(Component):
"""
Baseclass for data sources. Subclass and implement send_all - usually this
@@ -311,8 +347,8 @@ class DataSource(Component):
"""
def __init__(self, source_id):
Component.__init__(self)
self.id = source_id
self.cur_event = None
self.id = source_id
self.cur_event = None
def get_id(self):
return self.id
@@ -323,13 +359,13 @@ class DataSource(Component):
def send(self, event):
"""
event is expected to be a dict
sets id and type properties in the dict
sends to the data_socket.
event is expected to be a dict
sets id and type properties in the dict
sends to the data_socket.
"""
event['id'] = self.id
event['type'] = self.get_type()
self.data_socket.send(json.dumps(event))
def get_type(self):
raise NotImplemented
raise NotImplementedError
+1 -2
View File
@@ -1,4 +1,3 @@
import msgpack
#import msgpack
#import ujson
#import ultrajson_numpy
+87
View File
@@ -0,0 +1,87 @@
"""
Dummy simulator backported from Qexec for development on Zipline.
"""
import threading
import mock
from unittest2 import TestCase
from zipline.test.test_messaging import SimulatorTestCase
from zipline.monitor import Controller
from zipline.messaging import ComponentHost
import zipline.util as qutil
class DummyAllocator(object):
def __init__(self, ns):
self.idx = 0
self.sockets = [
'tcp://127.0.0.1:%s' % (10000 + n)
for n in xrange(ns)
]
def lease(self, n):
sockets = self.sockets[self.idx:self.idx+n]
self.idx += n
return sockets
def reaquire(self, *conn):
pass
class SimulatorBase(ComponentHost):
"""
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
"""
def __init__(self, addresses, gevent_needed=False):
"""
"""
ComponentHost.__init__(self, addresses, gevent_needed)
def simulate(self):
self.run()
def get_id(self):
return "Simulator"
class ThreadSimulator(SimulatorBase):
def __init__(self, addresses):
SimulatorBase.__init__(self, addresses)
def launch_controller(self):
thread = threading.Thread(target=self.controller.run)
thread.start()
self.cuc = thread
return thread
def launch_component(self, component):
thread = threading.Thread(target=component.run)
thread.start()
return thread
class ThreadPoolExecutor(SimulatorTestCase, TestCase):
allocator = DummyAllocator(100)
def setup_logging(self):
qutil.configure_logging()
# lazy import by design
self.logger = mock.Mock()
def setup_allocator(self):
pass
def get_simulator(self, addresses):
return ThreadSimulator(addresses)
def get_controller(self):
# Allocate two more sockets
controller_sockets = self.allocate_sockets(2)
return Controller(
controller_sockets[0],
controller_sockets[1],
logging = self.logger,
)
+2 -2
View File
@@ -150,7 +150,7 @@ class SimulatorTestCase(object):
sim.register_controller( con )
sim.register_components([ret1, ret2, client])
# Simulation
# Simulation
# ----------
sim.simulate()
@@ -197,7 +197,7 @@ class SimulatorTestCase(object):
sim.register_components([ret1, ret2, mavg1, mavg2, client])
sim.register_controller( con )
# Simulation
# Simulation
# ----------
sim.simulate()
+13 -13
View File
@@ -2,31 +2,31 @@
Transformations for common technical indicators.
TODO: add MACD transform
TODO: add trailing stop
"""
import datetime
from zipline.messaging import BaseTransform
import zipline.util as qutil
class MovingAverage(BaseTransform):
"""
Calculate a unweighted moving average for props['sid'] security
TODO: add sid -> mvavg dict.
Calculate a unweighted moving average for props['sid'] security
TODO: add sid -> mvavg dict.
"""
def __init__(self, name, days):
def __init__(self, name, days):
BaseTransform.__init__(self, name)
self.events = []
self.current_total = 0
self.window = datetime.timedelta(days = days)
self.window = datetime.timedelta(days = days)
def transform(self, event):
"""Update the moving average with the latest data point."""
self.events.append(event)
self.current_total += event['price']
event_date = qutil.parse_date(event['dt'])
index = 0
for cur_event in self.events:
cur_date = qutil.parse_date(cur_event['dt'])
@@ -36,12 +36,12 @@ class MovingAverage(BaseTransform):
index += 1
else:
break
if(len(self.events) == 0):
return 0.0
self.average = self.current_total/len(self.events)
self.state['value'] = self.average
return self.state