Refactoring tests to combine common code into a factory for the creation of test data sources. Also created a new zipline/lines module, which will
hold classes that instantiate entire zipline topologies.
This commit is contained in:
fawce
2012-03-17 23:15:38 -04:00
parent f7aa5c7c06
commit dcc471cc93
5 changed files with 285 additions and 163 deletions
+5 -3
View File
@@ -185,7 +185,7 @@ class PerformanceTracker():
def to_dict(self):
"""
Creates a dictionary representing the state of this tracker.
Returns a dict object of the form:
Returns a dict object of the form described in header comments.
"""
returns_list = [x.to_dict() for x in self.returns]
@@ -295,8 +295,8 @@ class PerformanceTracker():
#if self.result_stream:
## TODO: proper framing
#self.result_stream.send_pyobj(self.risk_report.to_dict())
self.result_stream.send_pyobj(None)
if self.result_stream:
self.result_stream.send_pyobj(None)
def round_to_nearest(self, x, base=5):
return int(base * round(float(x)/base))
@@ -368,6 +368,8 @@ class PerformancePeriod():
#cash balance at start of period
self.starting_cash = starting_cash
self.ending_cash = starting_cash
self.calculate_performance()
def calculate_performance(self):
self.ending_value = self.calculate_positions_value()
+1
View File
@@ -396,6 +396,7 @@ class TradingEnvironment(object):
self.period_start = period_start
self.period_end = period_end
self.capital_base = capital_base
for bm in benchmark_returns:
self.trading_days.append(bm.date)
self.trading_day_map[bm.date] = bm
+172
View File
@@ -0,0 +1,172 @@
"""
Ziplines are composed of multiple components connected by asynchronous
messaging. All ziplines follow a general topology of parallel sources,
datetimestamp serialization, parallel transformations, and finally sinks.
Furthermore, many ziplines have common needs. For example, all trade
simulations require a
:py:class:`~zipline.finance.trading.TradeSimulationClient`, an
:py:class:`~zipline.finance.trading.OrderSource`, and a
:py:class:`~zipline.finance.trading.TransactionSimulator` (a transform).
To establish best practices and minimize code replication, the lines module
provides complete zipline topologies. You can extend any zipline without
the need to extend the class. Simply instantiate any additional components
that you would like included in the zipline, and add them to the zipline
before invoking simulate.
"""
import mock
import pytz
from datetime import datetime, timedelta
from collections import defaultdict
from nose.tools import timed
import zipline.test.factory as factory
import zipline.util as qutil
import zipline.finance.risk as risk
import zipline.protocol as zp
import zipline.finance.performance as perf
import zipline.messaging as zmsg
from zipline.test.client import TestAlgorithm
from zipline.sources import SpecificEquityTrades
from zipline.finance.trading import TransactionSimulator, OrderDataSource, \
TradeSimulationClient
from zipline.simulator import AddressAllocator, Simulator
from zipline.monitor import Controller
class SimulatedTrading(object):
    """
    Zipline with::

        - _no_ data sources.
        - Trade simulation client, which is available to send callbacks on
          events and also accept orders to be simulated.
        - An order data source, which will receive orders from the trade
          simulation client, and feed them into the event stream to be
          serialized and ordered alongside all other data source events.
        - transaction simulation transformation, which receives the order
          events and estimates a theoretical execution price and volume.

    All components in this zipline are subject to heartbeat checks and
    a control monitor, which can kill the entire zipline in the event of
    exceptions in one of the components or an external request to end the
    simulation.
    """

    def __init__(self, trading_environment, allocator):
        """
        Lease sockets, build the controller and simulator, and register the
        trading client, order source and transaction simulator.

        :param trading_environment: environment passed to the
            TradeSimulationClient. Its ``frame_index`` attribute is
            overwritten here.
        :param allocator: socket allocator; must provide ``lease(n)`` and
            ``reaquire(*sockets)``.
        """
        self.allocator = allocator
        self.leased_sockets = []
        self.trading_environment = trading_environment
        self.sim_context = None

        # Six sockets feed the simulator's address map, the remaining two
        # are handed to the controller.
        sockets = self.allocate_sockets(8)
        addresses = {
            'sync_address': sockets[0],
            'data_address': sockets[1],
            'feed_address': sockets[2],
            'merge_address': sockets[3],
            'result_address': sockets[4],
            'order_address': sockets[5]
        }

        self.con = Controller(
            sockets[6],
            sockets[7],
            logging=qutil.LOGGER
        )

        self.sim = Simulator(addresses)

        self.trading_environment.frame_index = ['sid', 'volume', 'dt',
                                                'price', 'changed']

        self.clients = {}
        self.trading_client = TradeSimulationClient(self.trading_environment)
        self.clients[self.trading_client.get_id] = self.trading_client

        # setup all sources
        self.sources = {}
        self.order_source = OrderDataSource()
        self.sources[self.order_source.get_id] = self.order_source

        # setup transforms
        self.transaction_sim = TransactionSimulator()
        self.transforms = {}
        self.transforms[self.transaction_sim.get_id] = self.transaction_sim

        # register all components
        self.sim.register_components([
            self.trading_client,
            self.order_source,
            self.transaction_sim
        ])
        self.sim.register_controller(self.con)

        # Store the bound method itself. Calling shutdown() here would hand
        # the freshly leased sockets straight back to the allocator and
        # leave on_done set to None.
        self.sim.on_done = self.shutdown
        self.started = False

    def add_source(self, source):
        """
        Register an additional data source with the simulator and track it
        in ``self.sources`` keyed by its ``get_id``.

        Raises AssertionError if source is not a zmsg.DataSource, and
        ZiplineException if the simulation has already been started.
        """
        assert isinstance(source, zmsg.DataSource)
        self.check_started()
        self.sim.register_components([source])
        self.sources[source.get_id] = source

    def add_transform(self, transform):
        """
        Register an additional transform with the simulator and track it
        in ``self.transforms`` keyed by its ``get_id``.

        Raises AssertionError if transform is not a zmsg.BaseTransform, and
        ZiplineException if the simulation has already been started.
        """
        assert isinstance(transform, zmsg.BaseTransform)
        self.check_started()
        self.sim.register_components([transform])
        # Transforms belong in the transforms map, next to the built-in
        # transaction simulator, not in the sources map.
        self.transforms[transform.get_id] = transform

    def check_started(self):
        """
        Raise ZiplineException if simulate() has already been invoked.
        """
        if self.started:
            # Adjacent literals are used instead of a backslash inside the
            # string so that no source indentation leaks into the message.
            raise ZiplineException(
                "You cannot add sources after the "
                "simulation has begun."
            )

    def get_cumulative_performance(self):
        """
        Return the result of ``to_dict()`` on the trading client's
        cumulative performance period.
        """
        return self.trading_client.perf.cumulative_performance.to_dict()

    def allocate_sockets(self, n):
        """
        Allocate sockets local to this line, track them so
        we can gc after test run.

        :param n: positive int, number of sockets to lease.
        :returns: whatever ``allocator.lease(n)`` returned.
        """
        assert isinstance(n, int)
        assert n > 0

        leased = self.allocator.lease(n)
        self.leased_sockets.extend(leased)
        return leased

    def simulate(self, blocking=False):
        """
        Mark the line as started and launch the simulator.

        :param blocking: when true, join the context returned by
            ``sim.simulate()`` before returning.
        """
        self.started = True
        self.sim_context = self.sim.simulate()
        if blocking:
            self.sim_context.join()

    def shutdown(self):
        """
        Pass every socket leased by this line to ``allocator.reaquire``.
        """
        self.allocator.reaquire(*self.leased_sockets)

    #--------------------------------#
    # Component property accessors   #
    #--------------------------------#

    def get_positions(self):
        """
        returns current positions as a dict. draws from the cumulative
        performance period in the performance tracker.
        """
        perf = self.trading_client.perf.cumulative_performance
        positions = perf.get_positions()
        return positions
class ZiplineException(Exception):
    """
    Error raised by zipline topology classes.

    :param msg: human-readable description of the failure; it becomes the
        exception's only argument.
    """

    def __init__(self, msg):
        # The instance must be passed explicitly when calling the base
        # initializer through the class.
        Exception.__init__(self, msg)
+38
View File
@@ -5,6 +5,7 @@ import random
import zipline.util as qutil
import zipline.finance.risk as risk
import zipline.protocol as zp
from zipline.sources import SpecificEquityTrades
def load_market_data():
fp_bm = open("./zipline/test/benchmark.msgpack", "rb")
@@ -128,3 +129,40 @@ def create_returns_from_list(returns, start, trading_calendar):
current = current + one_day
return sorted(test_range, key=lambda(x):x.date)
def create_daily_trade_source(sids, trade_count, trading_environment):
    """
    creates trade_count trades for each sid in sids list.
    first trade will be on trading_environment.period_start, and daily
    thereafter for each sid. Thus, two sids should result in two trades per
    day.

    Every generated trade uses a price of 10.1 and a volume of 100.

    Important side-effect: trading_environment.period_end will be modified
    to match the day of the final trade. When no trades are generated
    (empty sids or a zero trade_count) period_end is left untouched.

    :param sids: iterable of security ids to generate trades for.
    :param trade_count: number of trades to generate per sid.
    :param trading_environment: supplies period_start and is passed through
        to create_trade_history.
    :returns: a SpecificEquityTrades source named "flat" wrapping the
        generated trades sorted by their dt.
    """
    # These do not vary per sid, so compute them once.
    start_date = trading_environment.period_start
    trade_time_increment = datetime.timedelta(days=1)

    trade_history = []
    for sid in sids:
        # Fresh lists per sid so the helper never shares them between calls.
        price = [10.1] * trade_count
        volume = [100] * trade_count

        generated_trades = create_trade_history(
            sid,
            price,
            volume,
            start_date,
            trade_time_increment,
            trading_environment
        )

        trade_history.extend(generated_trades)

    # A plain single-argument lambda is valid on both Python 2 and 3; the
    # parenthesised form only parses on Python 2.
    trade_history.sort(key=lambda trade: trade.dt)

    # set the trading environment's end to same dt as the last trade in the
    # history. Skipped when there are no trades to avoid an IndexError.
    if trade_history:
        trading_environment.period_end = trade_history[-1].dt

    source = SpecificEquityTrades("flat", trade_history)
    return source
+69 -160
View File
@@ -20,6 +20,7 @@ from zipline.finance.trading import TransactionSimulator, OrderDataSource, \
TradeSimulationClient
from zipline.simulator import AddressAllocator, Simulator
from zipline.monitor import Controller
from zipline.lines import SimulatedTrading
DEFAULT_TIMEOUT = 5 # seconds
@@ -34,31 +35,20 @@ class FinanceTestCase(TestCase):
self.benchmark_returns, self.treasury_curves = \
factory.load_market_data()
start = datetime.strptime("01/1/2006","%m/%d/%Y")
start = start.replace(tzinfo=pytz.utc)
self.trading_environment = risk.TradingEnvironment(
self.benchmark_returns,
self.treasury_curves
self.treasury_curves,
period_start = start,
capital_base = 100000.0
)
self.allocator = allocator
def allocate_sockets(self, n):
"""
Allocate sockets local to this test case, track them so
we can gc after test run.
"""
assert isinstance(n, int)
assert n > 0
leased = self.allocator.lease(n)
self.leased_sockets[self.id()].extend(leased)
return leased
@timed(DEFAULT_TIMEOUT)
def test_trade_feed_protocol(self):
# TODO: Perhaps something more self-documenting for variables names?
sid = 133
price = [10.0] * 4
volume = [100] * 4
@@ -164,172 +154,89 @@ class FinanceTestCase(TestCase):
# Just verify sending and receiving orders.
# --------------
# Allocate sockets for the simulator components
sockets = self.allocate_sockets(8)
addresses = {
'sync_address' : sockets[0],
'data_address' : sockets[1],
'feed_address' : sockets[2],
'merge_address' : sockets[3],
'result_address' : sockets[4],
'order_address' : sockets[5]
}
con = Controller(
sockets[6],
sockets[7],
logging = qutil.LOGGER
#
SID=133
sids = [133]
trade_count = 100
trade_source = factory.create_daily_trade_source(
sids,
trade_count,
self.trading_environment
)
sim = Simulator(addresses)
# Simulation Components
# ---------------------
# TODO: Perhaps something more self-documenting for variables names?
sid = 133
price = [10.1] * 16
volume = [100] * 16
start_date = datetime.strptime("02/1/2012","%m/%d/%Y")
start_date = start_date.replace(tzinfo=pytz.utc)
trade_time_increment = timedelta(days=1)
trade_history = factory.create_trade_history(
sid,
price,
volume,
start_date,
trade_time_increment,
self.trading_environment
)
set1 = SpecificEquityTrades("flat-133", trade_history)
self.trading_environment.period_start = trade_history[0].dt
self.trading_environment.period_end = trade_history[-1].dt
self.trading_environment.capital_base = 10000
self.trading_environment.frame_index = ['sid', 'volume', 'dt', \
'price', 'changed']
trading_client = TradeSimulationClient(self.trading_environment)
#client will send 10 orders for 100 shares of 133
test_algo = TestAlgorithm(133, 100, 10, trading_client)
order_source = OrderDataSource()
transaction_sim = TransactionSimulator()
sim.register_components([
trading_client,
order_source,
transaction_sim,
set1
])
sim.register_controller( con )
# Simulation
# ----------
sim_context = sim.simulate()
sim_context.join()
zipline = SimulatedTrading(
self.trading_environment,
self.allocator
)
zipline.add_source(trade_source)
order_amount = 100
order_count = 10
test_algo = TestAlgorithm(
SID,
order_amount,
order_count,
zipline.trading_client
)
zipline.simulate(blocking=True)
self.assertTrue(sim.ready())
self.assertFalse(sim.exception)
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
# TODO: Make more assertions about the final state of the components.
self.assertEqual(sim.feed.pending_messages(), 0, \
self.assertEqual(zipline.sim.feed.pending_messages(), 0, \
"The feed should be drained of all messages, found {n} remaining." \
.format(n=sim.feed.pending_messages()))
.format(n=zipline.sim.feed.pending_messages()))
@timed(DEFAULT_TIMEOUT)
def test_performance(self):
# verify order -> transaction -> portfolio position.
# --------------
# Allocate sockets for the simulator components
sockets = self.allocate_sockets(8)
addresses = {
'sync_address' : sockets[0],
'data_address' : sockets[1],
'feed_address' : sockets[2],
'merge_address' : sockets[3],
'result_address' : sockets[4],
'order_address' : sockets[5]
}
con = Controller(
sockets[6],
sockets[7],
logging = qutil.LOGGER
)
sim = Simulator(addresses)
# Simulation Components
# ---------------------
# TODO: Perhaps something more self-documenting for variables names?
# --------------
SID=133
sids = [133]
trade_count = 100
sid = 133
price = [10.1] * trade_count
volume = [100] * trade_count
start_date = datetime.strptime("02/1/2012","%m/%d/%Y")
start_date = start_date.replace(tzinfo=pytz.utc)
trade_time_increment = timedelta(days=1)
trade_history = factory.create_trade_history(
sid,
price,
volume,
start_date,
trade_time_increment,
self.trading_environment
trade_source = factory.create_daily_trade_source(
sids,
trade_count,
self.trading_environment
)
self.trading_environment.period_start = trade_history[0].dt
self.trading_environment.period_end = trade_history[-1].dt
self.trading_environment.capital_base = 10000
self.trading_environment.frame_index = ['sid', 'volume', 'dt', \
'price', 'changed']
set1 = SpecificEquityTrades("flat-133", trade_history)
#client sill send 10 orders for 100 shares of 133
trading_client = TradeSimulationClient(self.trading_environment)
test_algo = TestAlgorithm(133, 100, 10, trading_client)
order_source = OrderDataSource()
transaction_sim = TransactionSimulator()
sim.register_components([
trading_client,
order_source,
transaction_sim,
set1,
])
sim.register_controller( con )
# Simulation
# ----------
sim_context = sim.simulate()
sim_context.join()
zipline = SimulatedTrading(
self.trading_environment,
self.allocator
)
zipline.add_source(trade_source)
order_amount = 100
order_count = 25
test_algo = TestAlgorithm(
SID,
order_amount,
order_count,
zipline.trading_client
)
zipline.simulate(blocking=True)
self.assertEqual(
sim.feed.pending_messages(),
zipline.sim.feed.pending_messages(),
0,
"The feed should be drained of all messages, found {n} remaining." \
.format(n=sim.feed.pending_messages())
.format(n=zipline.sim.feed.pending_messages())
)
self.assertEqual(
sim.merge.pending_messages(),
zipline.sim.merge.pending_messages(),
0,
"The merge should be drained of all messages, found {n} remaining." \
.format(n=sim.merge.pending_messages())
.format(n=zipline.sim.merge.pending_messages())
)
self.assertEqual(
@@ -337,27 +244,29 @@ class FinanceTestCase(TestCase):
test_algo.incr,
"The test algorithm should send as many orders as specified.")
order_source = zipline.sources[zp.FINANCE_COMPONENT.ORDER_SOURCE]
self.assertEqual(
order_source.sent_count,
test_algo.count,
"The order source should have sent as many orders as the algo."
)
transaction_sim = zipline.transforms[zp.TRANSFORM_TYPE.TRANSACTION]
self.assertEqual(
transaction_sim.txn_count,
trading_client.perf.txn_count,
zipline.trading_client.perf.txn_count,
"The perf tracker should handle the same number of transactions \
as the simulator emits."
)
self.assertEqual(
len(trading_client.perf.cumulative_performance.positions),
len(zipline.get_positions()),
1,
"Portfolio should have one position."
)
self.assertEqual(
trading_client.perf.cumulative_performance.positions[133].sid,
133,
"Portfolio should have one position in 133."
zipline.get_positions()[SID]['sid'],
SID,
"Portfolio should have one position in " + str(SID)
)