mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 03:06:53 +08:00
send sockets have zero linger, sync object now serves as heartbeat, components die on error, all components error on heartbeat timeout
This commit is contained in:
+141
-77
@@ -6,6 +6,7 @@ import zmq
|
||||
import json
|
||||
import copy
|
||||
import threading
|
||||
import datetime
|
||||
|
||||
import qsim.util as qutil
|
||||
import qsim.messaging as qmsg
|
||||
@@ -15,7 +16,7 @@ class Simulator(object):
|
||||
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
|
||||
"""
|
||||
|
||||
def __init__(self, sources, transforms, client):
|
||||
def __init__(self, sources, transforms, client, feed=None, merge=None):
|
||||
"""
|
||||
"""
|
||||
self.sources = sources
|
||||
@@ -25,7 +26,7 @@ class Simulator(object):
|
||||
self.feed = None
|
||||
self.context = None
|
||||
self.sync_context = None
|
||||
self.syncservice = None
|
||||
self.sync_socket = None
|
||||
self.sync_register = {}
|
||||
self.sync_address = "tcp://127.0.0.1:{port}".format(port=10100)
|
||||
self.data_address = "tcp://127.0.0.1:{port}".format(port=10101)
|
||||
@@ -33,29 +34,45 @@ class Simulator(object):
|
||||
self.merge_address = "tcp://127.0.0.1:{port}".format(port=10103)
|
||||
self.result_address = "tcp://127.0.0.1:{port}".format(port=10104)
|
||||
|
||||
self.timeout = datetime.timedelta(seconds=1)
|
||||
|
||||
#workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108
|
||||
qutil.parse_date("2012/02/13-10:04:28.114")
|
||||
|
||||
if(feed == None):
|
||||
self.feed = DataFeed(self.sources.keys(), self.data_address, self.feed_address, qmsg.Sync(self,"DataFeed"))
|
||||
else:
|
||||
self.feed = feed
|
||||
|
||||
if(merge == None):
|
||||
#connect merge to feed, set expected transforms
|
||||
self.merge = TransformsMerge(self.feed_address,
|
||||
self.merge_address,
|
||||
self.result_address,
|
||||
qmsg.Sync(self,"TransformsMerge"),
|
||||
self.transforms.keys())
|
||||
else:
|
||||
self.merge = merge
|
||||
|
||||
def simulate(self):
|
||||
self.feed = DataFeed(self.sources.keys(), self.data_address, self.feed_address, qmsg.Sync(self,"DataFeed"))
|
||||
#launch the feed
|
||||
self.launch_component("DataFeed", self.feed)
|
||||
|
||||
#launch the data sources
|
||||
for name, data_source in self.sources.iteritems():
|
||||
data_source.data_address = self.data_address
|
||||
data_source.sync = qmsg.Sync(self, str(data_source.source_id))
|
||||
self.launch_component(name, data_source)
|
||||
qutil.LOGGER.info("datasources processes launched")
|
||||
|
||||
#connect all the transforms to the feed and merge
|
||||
#connect all the transforms to the feed and merge, launch each
|
||||
for name, transform in self.transforms.iteritems():
|
||||
transform.feed_address = self.feed_address #connect transform to receive feed.
|
||||
transform.merge_address = self.merge_address #connect transform to push results to merge
|
||||
transform.sync = qmsg.Sync(self, name) #synchronize the transform against this simulation.
|
||||
self.launch_component(name, transform) #start transforms
|
||||
|
||||
#connect merge to feed, set expected transforms
|
||||
self.merge = TransformsMerge(self.feed_address,
|
||||
self.merge_address,
|
||||
self.result_address,
|
||||
qmsg.Sync(self,"TransformsMerge"),
|
||||
self.transforms.keys())
|
||||
self.launch_component(name, transform) #start transforms
|
||||
|
||||
#launch merge
|
||||
self.launch_component("transforms merge", self.merge)
|
||||
qutil.LOGGER.info("transform processes launched")
|
||||
|
||||
@@ -66,7 +83,7 @@ class Simulator(object):
|
||||
qutil.LOGGER.info("client process launched")
|
||||
|
||||
self.sync_components()
|
||||
client_proc.join() #wait for client to complete processing
|
||||
#client_proc.join() #wait for client to complete processing
|
||||
|
||||
def launch_component(self, name, component):
|
||||
qutil.LOGGER.info("starting {name}".format(name=name))
|
||||
@@ -81,32 +98,54 @@ class Simulator(object):
|
||||
return proc
|
||||
|
||||
def register_sync(self, sync_id):
|
||||
self.sync_register[sync_id] = "UNCONFIRMED"
|
||||
|
||||
def registration_complete(self):
|
||||
for status in self.sync_register.values():
|
||||
if status == "UNCONFIRMED":
|
||||
return False
|
||||
|
||||
return True
|
||||
self.sync_register[sync_id] = datetime.datetime.utcnow()
|
||||
|
||||
def unregister_sync(self, sync_id):
|
||||
del(self.sync_register[sync_id])
|
||||
|
||||
def is_timed_out(self):
|
||||
cur_time = datetime.datetime.utcnow()
|
||||
if(len(self.sync_register) == 0):
|
||||
qutil.LOGGER.info("**********Simulator sync register is empty.")
|
||||
return True
|
||||
for source, last_dt in self.sync_register.iteritems():
|
||||
if((cur_time - last_dt) > self.timeout):
|
||||
qutil.LOGGER.info("Time out for {source}".format(source=source))
|
||||
return True
|
||||
return False
|
||||
|
||||
def sync_components(self):
|
||||
# Socket to receive signals
|
||||
self.context = zmq.Context()
|
||||
qutil.LOGGER.info("waiting for all datasources and clients to be ready")
|
||||
self.syncservice = self.context.socket(zmq.REP)
|
||||
self.syncservice.bind(self.sync_address)
|
||||
self.sync_socket = self.context.socket(zmq.REP)
|
||||
self.sync_socket.bind(self.sync_address)
|
||||
self.sync_socket.setsockopt(zmq.LINGER,0)
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.sync_socket, zmq.POLLIN)
|
||||
|
||||
while not self.registration_complete():
|
||||
while not self.is_timed_out():
|
||||
# wait for synchronization request
|
||||
msg = self.syncservice.recv()
|
||||
self.sync_register[msg] = "CONFIRMED"
|
||||
#qutil.LOGGER.info("confirmed {id}".format(id=msg))
|
||||
# send synchronization reply
|
||||
self.syncservice.send('CONFIRMED')
|
||||
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
|
||||
|
||||
self.syncservice.close()
|
||||
qutil.LOGGER.info("sync'd all datasources and clients")
|
||||
if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN:
|
||||
try:
|
||||
msg = self.sync_socket.recv()
|
||||
parts = msg.split(':')
|
||||
sync_id = parts[0]
|
||||
status = parts[1]
|
||||
if(status == "DONE"):
|
||||
self.unregister_sync(sync_id)
|
||||
else:
|
||||
self.sync_register[sync_id] = datetime.datetime.utcnow()
|
||||
#qutil.LOGGER.info("confirmed {id}".format(id=msg))
|
||||
# send synchronization reply
|
||||
self.sync_socket.send('ack')
|
||||
except:
|
||||
continue
|
||||
|
||||
self.sync_socket.close()
|
||||
qutil.LOGGER.info("simulator heartbeat stopped.")
|
||||
|
||||
|
||||
|
||||
@@ -123,30 +162,43 @@ class DataFeed(object):
|
||||
self.feed_socket = None
|
||||
self.data_socket = None
|
||||
self.context = None
|
||||
self.poller = None
|
||||
|
||||
def run(self):
|
||||
def open(self):
|
||||
# Prepare our context and sockets
|
||||
try:
|
||||
self.context = zmq.Context()
|
||||
self.context = zmq.Context()
|
||||
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
|
||||
#see: http://zguide.zeromq.org/py:taskwork2
|
||||
self.data_socket = self.context.socket(zmq.PULL)
|
||||
self.data_socket.bind(self.data_address)
|
||||
|
||||
#create the feed
|
||||
self.feed_socket = self.context.socket(zmq.PUB)
|
||||
self.feed_socket.bind(self.feed_address)
|
||||
self.feed_socket.setsockopt(zmq.LINGER,0)
|
||||
|
||||
ds_finished_counter = 0
|
||||
|
||||
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
|
||||
#see: http://zguide.zeromq.org/py:taskwork2
|
||||
self.data_socket = self.context.socket(zmq.PULL)
|
||||
self.data_socket.bind(self.data_address)
|
||||
self.data_buffer.out_socket = self.feed_socket
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.data_socket, zmq.POLLIN)
|
||||
|
||||
#create the feed
|
||||
self.feed_socket = self.context.socket(zmq.PUB)
|
||||
self.feed_socket.bind(self.feed_address)
|
||||
self.sync.open()
|
||||
|
||||
self.data_buffer.out_socket = self.feed_socket
|
||||
def close(self):
|
||||
self.data_socket.close()
|
||||
self.feed_socket.close()
|
||||
self.sync.close()
|
||||
self.context.term()
|
||||
|
||||
self.sync.confirm()
|
||||
qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address))
|
||||
|
||||
while True:
|
||||
def handle_all(self):
|
||||
qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address))
|
||||
ds_finished_counter = 0
|
||||
while self.sync.confirm():
|
||||
# wait for synchronization reply from the host
|
||||
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
|
||||
|
||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||
message = self.data_socket.recv()
|
||||
|
||||
event = json.loads(message)
|
||||
if(event["type"] == "DONE"):
|
||||
ds_finished_counter += 1
|
||||
@@ -156,20 +208,23 @@ class DataFeed(object):
|
||||
self.data_buffer.append(event[u's'], event)
|
||||
self.data_buffer.send_next()
|
||||
|
||||
|
||||
#drain any remaining messages in the buffer
|
||||
self.data_buffer.drain()
|
||||
|
||||
#send the DONE message
|
||||
self.feed_socket.send("DONE")
|
||||
qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count,
|
||||
m=self.data_buffer.sent_count))
|
||||
#drain any remaining messages in the buffer
|
||||
self.data_buffer.drain()
|
||||
|
||||
#send the DONE message
|
||||
self.feed_socket.send("DONE")
|
||||
qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count,
|
||||
m=self.data_buffer.sent_count))
|
||||
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.open()
|
||||
self.handle_all()
|
||||
except:
|
||||
qutil.LOGGER.exception("Exception in Feed, attempting to close.")
|
||||
finally:
|
||||
self.data_socket.close()
|
||||
self.feed_socket.close()
|
||||
self.context.term()
|
||||
self.close()
|
||||
|
||||
|
||||
|
||||
@@ -204,7 +259,7 @@ class BaseTransform(object):
|
||||
self.open()
|
||||
self.process_all()
|
||||
except:
|
||||
qutil.LOGGER.exception("Exception during merge processing, attempting to close merge.")
|
||||
qutil.LOGGER.exception("Exception during transform processing, attempting to close merge.")
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
@@ -221,9 +276,15 @@ class BaseTransform(object):
|
||||
self.feed_socket.connect(self.feed_address)
|
||||
self.feed_socket.setsockopt(zmq.SUBSCRIBE,'')
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.feed_socket, zmq.POLLIN)
|
||||
|
||||
#create the result PUSH
|
||||
self.result_socket = self.context.socket(zmq.PUSH)
|
||||
self.result_socket.connect(self.merge_address)
|
||||
self.result_socket.setsockopt(zmq.LINGER,0)
|
||||
|
||||
self.sync.open()
|
||||
|
||||
def process_all(self):
|
||||
"""
|
||||
@@ -233,21 +294,22 @@ class BaseTransform(object):
|
||||
- send the transformed event
|
||||
"""
|
||||
qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name']))
|
||||
self.sync.confirm()
|
||||
|
||||
while True:
|
||||
message = self.feed_socket.recv()
|
||||
if(message == "DONE"):
|
||||
qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name']))
|
||||
self.result_socket.send("DONE")
|
||||
break
|
||||
self.received_count += 1
|
||||
event = json.loads(message)
|
||||
cur_state = self.transform(event)
|
||||
cur_state['dt'] = event['dt']
|
||||
cur_state['name'] = self.state['name']
|
||||
self.result_socket.send(json.dumps(cur_state))
|
||||
self.sent_count += 1
|
||||
while self.sync.confirm():
|
||||
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
|
||||
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
|
||||
message = self.feed_socket.recv()
|
||||
if(message == "DONE"):
|
||||
qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name']))
|
||||
self.result_socket.send("DONE")
|
||||
break
|
||||
self.received_count += 1
|
||||
event = json.loads(message)
|
||||
cur_state = self.transform(event)
|
||||
cur_state['dt'] = event['dt']
|
||||
cur_state['name'] = self.state['name']
|
||||
self.result_socket.send(json.dumps(cur_state))
|
||||
self.sent_count += 1
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
@@ -260,6 +322,7 @@ class BaseTransform(object):
|
||||
|
||||
self.feed_socket.close()
|
||||
self.result_socket.close()
|
||||
self.sync.close()
|
||||
self.context.term()
|
||||
|
||||
def transform(self, event):
|
||||
@@ -320,6 +383,7 @@ class TransformsMerge(object):
|
||||
#create the result PUSH
|
||||
self.result_socket = self.context.socket(zmq.PUSH)
|
||||
self.result_socket.bind(self.result_address)
|
||||
self.result_socket.setsockopt(zmq.LINGER,0)
|
||||
|
||||
#create the transform PULL.
|
||||
self.transform_socket = self.context.socket(zmq.PULL)
|
||||
@@ -331,7 +395,7 @@ class TransformsMerge(object):
|
||||
self.poller.register(self.feed_socket, zmq.POLLIN)
|
||||
self.poller.register(self.transform_socket, zmq.POLLIN)
|
||||
|
||||
self.sync.confirm()
|
||||
self.sync.open()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
@@ -350,8 +414,8 @@ class TransformsMerge(object):
|
||||
sent to the result socket.
|
||||
"""
|
||||
done_count = 0
|
||||
while True:
|
||||
socks = dict(self.poller.poll())
|
||||
while self.sync.confirm():
|
||||
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
|
||||
|
||||
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
|
||||
message = self.feed_socket.recv()
|
||||
|
||||
+36
-14
@@ -108,22 +108,44 @@ class Sync(object):
|
||||
to delay the start of the host until initial setup is complete."""
|
||||
|
||||
def __init__(self, host, name):
|
||||
self.host = host
|
||||
self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1())
|
||||
self.host = host
|
||||
self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1())
|
||||
self.context = None
|
||||
self.sync_socket = None
|
||||
self.poller = None
|
||||
self.host.register_sync(self.sync_id)
|
||||
|
||||
#qutil.LOGGER.info("registered {id} with host".format(id=self.sync_id))
|
||||
|
||||
def open(self):
|
||||
self.context = zmq.Context()
|
||||
#synchronize with host
|
||||
self.sync_socket = self.context.socket(zmq.REQ)
|
||||
self.sync_socket.connect(self.host.sync_address)
|
||||
self.sync_socket.setsockopt(zmq.LINGER,0)
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.sync_socket, zmq.POLLIN)
|
||||
|
||||
def confirm(self):
|
||||
"""Confirm readiness with the Host."""
|
||||
context = zmq.Context()
|
||||
#synchronize with host
|
||||
sync_socket = context.socket(zmq.REQ)
|
||||
sync_socket.connect(self.host.sync_address)
|
||||
# send a synchronization request to the host
|
||||
sync_socket.send(self.sync_id)
|
||||
# wait for synchronization reply from the host
|
||||
sync_socket.recv()
|
||||
sync_socket.close()
|
||||
context.term()
|
||||
qutil.LOGGER.info("sync'd host from {id}".format(id = self.sync_id))
|
||||
|
||||
try:
|
||||
# send a synchronization request to the host
|
||||
self.sync_socket.send(self.sync_id + ":RUNNING", zmq.NOBLOCK)
|
||||
# wait for synchronization reply from the host
|
||||
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
|
||||
|
||||
if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN:
|
||||
message = self.sync_socket.recv()
|
||||
return True
|
||||
except:
|
||||
qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id))
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.sync_socket.send(self.sync_id + ":DONE", zmq.NOBLOCK)
|
||||
self.sync_socket.close()
|
||||
self.context.term()
|
||||
except:
|
||||
pass #just don't want to error out on closing
|
||||
|
||||
+30
-11
@@ -33,14 +33,19 @@ class DataSource(object):
|
||||
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
|
||||
self.data_socket = self.context.socket(zmq.PUSH)
|
||||
self.data_socket.connect(self.data_address)
|
||||
self.data_socket.setsockopt(zmq.LINGER,0)
|
||||
|
||||
self.sync.confirm()
|
||||
self.sync.open()
|
||||
|
||||
def run(self):
|
||||
"""Fully execute this datasource."""
|
||||
self.open()
|
||||
self.send_all()
|
||||
self.close()
|
||||
try:
|
||||
self.open()
|
||||
self.send_all()
|
||||
except:
|
||||
qutil.LOGGER.info("Exception running datasource.")
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def send_all(self):
|
||||
"""Subclasses must implement this method."""
|
||||
@@ -52,20 +57,35 @@ class DataSource(object):
|
||||
sets source_id and type properties in the dict
|
||||
sends to the data_socket.
|
||||
"""
|
||||
self.sync.confirm()
|
||||
event['s'] = self.source_id
|
||||
event['type'] = 'event'
|
||||
self.data_socket.send(json.dumps(event))
|
||||
self.data_socket.send(json.dumps(event), zmq.NOBLOCK)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the zmq context and sockets.
|
||||
"""
|
||||
done_msg = {}
|
||||
done_msg['type'] = 'DONE'
|
||||
done_msg['s'] = self.source_id
|
||||
self.data_socket.send(json.dumps(done_msg))
|
||||
qutil.LOGGER.info("sending DONE message.")
|
||||
try:
|
||||
done_msg = {}
|
||||
done_msg['type'] = 'DONE'
|
||||
done_msg['s'] = self.source_id
|
||||
self.data_socket.send(json.dumps(done_msg), zmq.NOBLOCK)
|
||||
except:
|
||||
qutil.LOGGER.exception("failed to send DONE message")
|
||||
pass #continue with the closing.
|
||||
|
||||
qutil.LOGGER.info("closing data socket")
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
qutil.LOGGER.info("closing sync")
|
||||
self.sync.close()
|
||||
qutil.LOGGER.info("closing context")
|
||||
try:
|
||||
self.context.term()
|
||||
qutil.LOGGER.info("done")
|
||||
except:
|
||||
qutil.LOGGER.exception("error closing context")
|
||||
qutil.LOGGER.info("finished processing data source")
|
||||
|
||||
class RandomEquityTrades(DataSource):
|
||||
@@ -89,7 +109,6 @@ class RandomEquityTrades(DataSource):
|
||||
'volume':random.randrange(100,10000,100)}
|
||||
self.send(event)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+21
-15
@@ -25,32 +25,38 @@ class TestClient(object):
|
||||
qutil.LOGGER.info("connecting to {address}".format(address=self.address))
|
||||
self.data_feed.connect(self.address)
|
||||
|
||||
self.sync.confirm()
|
||||
self.sync.open()
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.data_feed, zmq.POLLIN)
|
||||
|
||||
qutil.LOGGER.info("Starting the client loop")
|
||||
|
||||
prev_dt = None
|
||||
while True:
|
||||
msg = self.data_feed.recv()
|
||||
if(msg == "DONE"):
|
||||
qutil.LOGGER.info("DONE!")
|
||||
break
|
||||
self.received_count += 1
|
||||
event = json.loads(msg)
|
||||
if(prev_dt != None):
|
||||
if(not event['dt'] >= prev_dt):
|
||||
raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt))
|
||||
while self.sync.confirm():
|
||||
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
|
||||
if self.data_feed in socks and socks[self.data_feed] == zmq.POLLIN:
|
||||
msg = self.data_feed.recv()
|
||||
if(msg == "DONE"):
|
||||
qutil.LOGGER.info("DONE!")
|
||||
break
|
||||
self.received_count += 1
|
||||
event = json.loads(msg)
|
||||
if(prev_dt != None):
|
||||
if(not event['dt'] >= prev_dt):
|
||||
raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt))
|
||||
|
||||
prev_dt = event['dt']
|
||||
if(self.received_count % 100 == 0):
|
||||
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
|
||||
prev_dt = event['dt']
|
||||
if(self.received_count % 100 == 0):
|
||||
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
|
||||
|
||||
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
|
||||
except:
|
||||
self.error = True
|
||||
qutil.LOGGER.exception("Error in test client.")
|
||||
qutil.LOGGER.exception("**********************Error in test client.")
|
||||
finally:
|
||||
self.data_feed.close()
|
||||
self.sync.close()
|
||||
self.context.term()
|
||||
|
||||
self.utest.assertEqual(self.expected_msg_count, self.received_count,
|
||||
|
||||
@@ -5,11 +5,13 @@ Test suite for the messaging infrastructure of QSim.
|
||||
|
||||
import unittest2 as unittest
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
from qsim.core import Simulator
|
||||
from qsim.core import Simulator, DataFeed
|
||||
from qsim.transforms.technical import MovingAverage
|
||||
from qsim.sources import RandomEquityTrades
|
||||
import qsim.util as qutil
|
||||
import qsim.messaging as qmsg
|
||||
|
||||
from qsim.test.client import TestClient
|
||||
|
||||
@@ -56,3 +58,26 @@ class MessagingTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(sim.feed.data_buffer.pending_messages(), 0, "The feed should be drained of all messages.")
|
||||
|
||||
def test_zerror_in_feed(self):
|
||||
ret1 = RandomEquityTrades(133, "ret1", 400)
|
||||
ret2 = RandomEquityTrades(134, "ret2", 400)
|
||||
sources = {"ret1":ret1, "ret2":ret2}
|
||||
mavg1 = MovingAverage("mavg1", 30)
|
||||
mavg2 = MovingAverage("mavg2", 60)
|
||||
transforms = {"mavg1":mavg1, "mavg2":mavg2}
|
||||
client = TestClient(self, expected_msg_count=0)
|
||||
sim = Simulator(sources, transforms, client)
|
||||
sim.feed = DataFeedErr(sources.keys(), sim.data_address, sim.feed_address, qmsg.Sync(sim, "DataFeedErrorGenerator"))
|
||||
sim.simulate()
|
||||
|
||||
class DataFeedErr(DataFeed):
|
||||
"""Helper class for testing, simulates exceptions inside the DataFeed"""
|
||||
|
||||
def __init__(self, source_list, data_address, feed_address, sync):
|
||||
DataFeed.__init__(self, source_list, data_address, feed_address, sync)
|
||||
|
||||
def handle_all(self):
|
||||
#time.sleep(1000)
|
||||
raise Exception("simulated error in data feed from test helper")
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ TODO: add trailing stop
|
||||
"""
|
||||
import datetime
|
||||
from qsim.core import BaseTransform
|
||||
import qsim.util as qutil
|
||||
|
||||
class MovingAverage(BaseTransform):
|
||||
"""
|
||||
@@ -15,31 +16,32 @@ class MovingAverage(BaseTransform):
|
||||
|
||||
def __init__(self, name, days):
|
||||
BaseTransform.__init__(self, name)
|
||||
self.events = []
|
||||
|
||||
self.window = datetime.timedelta(days = days)
|
||||
|
||||
|
||||
|
||||
self.events = []
|
||||
self.current_total = 0
|
||||
self.window = datetime.timedelta(days = days)
|
||||
|
||||
def transform(self, event):
|
||||
"""Update the moving average with the latest data point."""
|
||||
|
||||
#self.events.append(event)
|
||||
self.events.append(event)
|
||||
self.current_total += event['price']
|
||||
event_date = qutil.parse_date(event['dt'])
|
||||
|
||||
#filter the event list to the window length.
|
||||
#self.events = [x for x in self.events if (qutil.parse_date(x['dt']) - qutil.parse_date(event['dt'])) <= self.window]
|
||||
index = 0
|
||||
for cur_event in self.events:
|
||||
cur_date = qutil.parse_date(cur_event['dt'])
|
||||
if(cur_date - event_date):
|
||||
self.events.pop(index)
|
||||
self.current_total -= cur_event['price']
|
||||
index += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if(len(self.events) == 0):
|
||||
return 0.0
|
||||
|
||||
self.average = self.current_total/len(self.events)
|
||||
|
||||
#if(len(self.events) == 0):
|
||||
# return 0.0
|
||||
|
||||
#total = 0.0
|
||||
#for event in self.events:
|
||||
# total += event['price']
|
||||
|
||||
#self.average = total/len(self.events)
|
||||
|
||||
#self.state['value'] = self.average
|
||||
self.state['value'] = 10
|
||||
self.state['value'] = self.average
|
||||
return self.state
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ def parse_date(dt_str):
|
||||
if(dt_str == None):
|
||||
return None
|
||||
parts = dt_str.split(".")
|
||||
dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(microsecond=int(parts[1]+"000"), tzinfo = pytz.utc)
|
||||
dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(microsecond=int(parts[1]+"000")).replace(tzinfo = pytz.utc)
|
||||
return dt
|
||||
|
||||
def format_date(dt):
|
||||
|
||||
Reference in New Issue
Block a user